
Diptera wing classification using Topological Data Analysis

Authors

Guilherme Vituri F. Pinto (Universidade Estadual Paulista)

Sergio Ura (Northon)

Published

February 24, 2026

Abstract

We apply tools from Topological Data Analysis (TDA) to classify Diptera families based on wing venation patterns. Using multiple filtration strategies — Vietoris-Rips on point clouds, directional height filtrations (8 directions), radial filtrations, Euclidean Distance Transform filtrations, and grayscale sublevel-set (cubical) filtrations on wing images — we extract both H0 and H1 topological features (persistence images, Betti curves, persistence landscapes and summary statistics) and compare distance-based and feature-based classifiers via leave-one-out cross-validation. Feature selection via Random Forest importance and nested LOOCV provide honest, unbiased accuracy estimates.

Keywords

Topological Data Analysis, Persistent homology, Diptera classification, Wing venation

In [2]:
using TDAfly, TDAfly.Preprocessing, TDAfly.TDA, TDAfly.Analysis
using Images: mosaicview, Gray
using Plots: plot, display, heatmap, scatter, bar
using StatsPlots: boxplot
using PersistenceDiagrams
using PersistenceDiagrams: BettiCurve, Landscape, PersistenceImage
using DataFrames
using Distances: euclidean
using LIBSVM
using StatsBase: mean

1 Introduction

The order Diptera (true flies) comprises over 150,000 described species across more than 150 families. Wing venation patterns are a classical diagnostic character in Diptera systematics: the arrangement, branching and connectivity of veins vary markedly across families and provide a natural morphological signature.

In this work, we apply Topological Data Analysis (TDA) to the problem of classifying Diptera families from wing images. TDA provides a framework for extracting shape descriptors that are robust to continuous deformations — exactly the kind of invariance desirable when comparing biological structures that vary in scale, orientation and minor deformations across individuals.

We employ five complementary filtration strategies:

  1. Vietoris-Rips filtration on point-cloud samples of wing silhouettes
  2. Directional height filtrations (8 directions) that sweep across the wing along different axes
  3. Radial filtration from the wing centroid to the periphery
  4. Euclidean Distance Transform (EDT) filtration capturing vein thickness hierarchy
  5. Grayscale sublevel-set (cubical) filtration on the raw wing image

For each filtration, we compute both H0 (connected components / vein branching) and H1 (loops / enclosed cells) persistence, then vectorize into feature representations (persistence images, Betti curves, persistence landscapes) and feed into classifiers. Feature selection and nested cross-validation provide honest accuracy estimates.

2 Methods

2.1 Data loading and preprocessing

All images are in the images/processed directory. For each image, we load it, apply a Gaussian blur (to close small gaps in the wing membrane and keep it connected), crop to the bounding box, and resize to 150 pixels of height.

In [3]:
all_paths = readdir("images/processed", join = true)
all_filenames = basename.(all_paths) .|> (x -> replace(x, ".png" => ""))

function extract_family(name)
    family_raw = lowercase(split(name, r"[\s\-]")[1])
    if family_raw in ("bibionidae", "biobionidae")
        return "Bibionidae"
    elseif family_raw in ("sciaridae", "scaridae")
        return "Sciaridae"
    elseif family_raw == "simulidae"
        return "Simuliidae"
    else
        return titlecase(family_raw)
    end
end

function canonical_id(name)
    family = extract_family(name)
    parts = split(name, r"[\s\-]")
    number = parts[end]
    "$(family)-$(number)"
end

# Deduplicate (space vs hyphen variants of the same file)
seen = Set{String}()
keep_idx = Int[]
for (i, fname) in enumerate(all_filenames)
    cid = canonical_id(fname)
    if !(cid in seen)
        push!(seen, cid)
        push!(keep_idx, i)
    end
end

paths = all_paths[keep_idx]
species = all_filenames[keep_idx]
families = extract_family.(species)

individuals = map(species) do specie
    parts = split(specie, r"[\s\-]")
    string(extract_family(specie)[1]) * "-" * parts[end]
end

println("Total images after deduplication: $(length(paths))")
println("Families: ", sort(unique(families)))
println("\nSamples per family:")
for f in sort(unique(families))
    println("  $(f): $(count(==(f), families))")
end
Total images after deduplication: 72
Families: ["Asilidae", "Bibionidae", "Ceratopogonidae", "Chironomidae", "Pelecorhynchidae", "Rhagionidae", "Sciaridae", "Simuliidae", "Tabanidae", "Tipulidae"]

Samples per family:
  Asilidae: 8
  Bibionidae: 6
  Ceratopogonidae: 8
  Chironomidae: 8
  Pelecorhynchidae: 2
  Rhagionidae: 4
  Sciaridae: 6
  Simuliidae: 7
  Tabanidae: 11
  Tipulidae: 12

2.1.1 Excluding small families

Families with fewer than 3 samples (e.g. Pelecorhynchidae with \(n=2\)) can distort cross-validation results: a single misclassification shifts that family's accuracy by 50 percentage points. We provide a filtered version and run the analysis both ways.

In [4]:
MIN_FAMILY_SIZE = 3
family_counts = Dict(f => count(==(f), families) for f in unique(families))
small_families = [f for (f, c) in family_counts if c < MIN_FAMILY_SIZE]

if !isempty(small_families)
    println("Families with < $MIN_FAMILY_SIZE samples (excluded from filtered analysis):")
    for f in sort(small_families)
        println("  $(f): $(family_counts[f]) samples")
    end
end

# Build filtered indices
keep_filtered = [i for i in eachindex(families) if family_counts[families[i]] >= MIN_FAMILY_SIZE]
paths_filtered = paths[keep_filtered]
species_filtered = species[keep_filtered]
families_filtered = families[keep_filtered]
individuals_filtered = individuals[keep_filtered]

println("\nFiltered dataset: $(length(keep_filtered)) samples, $(length(unique(families_filtered))) families")
Families with < 3 samples (excluded from filtered analysis):
  Pelecorhynchidae: 2 samples

Filtered dataset: 70 samples, 9 families
In [5]:
wings = load_wing.(paths, blur = 1.3)
Xs = map(wings) do w
    image_to_r2(w; ensure_connected = true, connectivity = 8)
end;
In [6]:
mosaicview(wings, ncol = 6, fillvalue = 1)

2.2 Example: forcing connectivity on 5 wings

The chunk below selects 5 wings (prioritizing those with the largest number of disconnected components before correction), then compares the binary pixel set before and after connect_pixel_components.

In [7]:
threshold_conn = 0.2
conn = 8

component_count_before = map(wings) do w
    ids0 = findall_ids(>(threshold_conn), image_to_array(w))
    length(pixel_components(ids0; connectivity = conn))
end

demo_idx = sortperm(component_count_before, rev = true)[1:min(5, length(wings))]

function ids_to_mask(ids)
    isempty(ids) && return zeros(Float32, 1, 1)
    xs = first.(ids)
    ys = last.(ids)
    M = zeros(Float32, maximum(xs), maximum(ys))
    for p in ids
        M[p[1], p[2]] = 1f0
    end
    M
end

demo_connectivity_df = DataFrame(
    sample = String[],
    n_components_before = Int[],
    n_components_after = Int[],
    n_pixels_before = Int[],
    n_pixels_after = Int[],
)

panel_plots = Any[]
for idx in demo_idx
    ids_before = findall_ids(>(threshold_conn), image_to_array(wings[idx]))
    ids_after = connect_pixel_components(ids_before; connectivity = conn)

    n_before = length(pixel_components(ids_before; connectivity = conn))
    n_after = length(pixel_components(ids_after; connectivity = conn))

    push!(demo_connectivity_df, (
        species[idx],
        n_before,
        n_after,
        length(ids_before),
        length(ids_after),
    ))

    M_before = ids_to_mask(ids_before)
    M_after = ids_to_mask(ids_after)

    p_before = heatmap(
        M_before[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "Before: $(species[idx])\ncomponents = $(n_before)",
    )

    p_after = heatmap(
        M_after[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "After: $(species[idx])\ncomponents = $(n_after)",
    )

    push!(panel_plots, p_before)
    push!(panel_plots, p_after)
end

plot(panel_plots..., layout = (length(demo_idx), 2), size = (900, 260 * length(demo_idx)))
In [8]:
demo_connectivity_df
5×5 DataFrame
 Row │ sample           n_components_before  n_components_after  n_pixels_before  n_pixels_after
     │ String           Int64                Int64               Int64            Int64
─────┼──────────────────────────────────────────────────────────────────────────────────────────
   1 │ biobionidae 11                   265                   1             6368            6864
   2 │ chironomidae 17                  220                   1            14610           15077
   3 │ biobionidae 9                    216                   1             6026            6545
   4 │ chironomidae 13                  206                   1            12701           13193
   5 │ chironomidae 16                  194                   1            14862           15349

3 Topological feature extraction

We now compute persistent homology using five filtration strategies. For the Vietoris-Rips filtration on connected point clouds, H0 is uninformative (single infinite bar), so we use only H1. However, for cubical filtrations (directional, radial, EDT, grayscale), H0 is highly informative — it captures when disconnected vein segments merge as the filtration parameter grows, directly encoding vein count and branching patterns. We therefore compute both H0 and H1 for all cubical-based filtrations.

What is persistent homology?

Persistent homology is the main tool of TDA. Given a shape or dataset, it tracks how topological features — connected components (dimension 0), loops (dimension 1), voids (dimension 2), etc. — appear and disappear as we “grow” the shape through a filtration parameter. Each feature has a birth time (when it appears) and a death time (when it gets filled in). The collection of all (birth, death) pairs is called a persistence diagram. Features with long lifetimes (high persistence = death \(-\) birth) represent genuine topological structure, while short-lived features are typically noise.
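The birth/death bookkeeping can be made concrete with a hand-rolled sketch (not part of TDAfly or PersistenceDiagrams): 0-dimensional sublevel-set persistence of a 1-D function, computed with union-find and the elder rule (when two components merge, the one with the later birth dies).

```julia
# Toy H0 sublevel-set persistence of a 1-D function via union-find.
# Each local minimum births a component; merges record (birth, death) pairs.
function sublevel_h0(f::Vector{<:Real})
    n = length(f)
    order = sortperm(f)                 # process vertices by increasing value
    parent = zeros(Int, n)              # 0 = vertex not yet added
    birth = Dict{Int,Float64}()         # root => birth value
    pairs = Tuple{Float64,Float64}[]
    find(i) = parent[i] == i ? i : (parent[i] = find(parent[i]))
    for v in order
        parent[v] = v
        birth[v] = f[v]
        for u in (v - 1, v + 1)         # the two 1-D neighbours
            (1 <= u <= n && parent[u] != 0) || continue
            ru, rv = find(u), find(v)
            ru == rv && continue
            if birth[ru] > birth[rv]    # elder rule: younger component dies
                ru, rv = rv, ru
            end
            birth[rv] < f[v] && push!(pairs, (birth[rv], f[v]))
            parent[rv] = ru
        end
    end
    push!(pairs, (birth[find(order[1])], Inf))  # essential class never dies
    sort(pairs)
end

sublevel_h0([0.0, 2.0, 1.0, 3.0, 0.5])
```

The three local minima of `[0.0, 2.0, 1.0, 3.0, 0.5]` each birth a component; the global minimum's component persists forever, giving the diagram `[(0.0, Inf), (0.5, 3.0), (1.0, 2.0)]`.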

3.1 Strategy 1: Vietoris-Rips filtration on point clouds

Vietoris-Rips filtration

Given a set of points in \(\mathbb{R}^n\), the Vietoris-Rips complex at scale \(\varepsilon\) connects any subset of points that are pairwise within distance \(\varepsilon\). As \(\varepsilon\) increases from 0, we obtain a nested sequence of simplicial complexes — the Rips filtration. This is the most common filtration in TDA for point-cloud data. It is computationally expensive (since it must consider all pairwise distances), which is why we subsample the point clouds.
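A toy computation (independent of the pipeline above) makes the Rips birth/death mechanics concrete: for the four corners of a unit square, the H1 loop is born when the sides appear at \(\varepsilon = 1\) and dies when the diagonals appear at \(\varepsilon = \sqrt{2}\).

```julia
# Four corners of a unit square; a Rips edge appears once ε reaches
# the pairwise distance between its endpoints.
pts = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
d(p, q) = hypot(p[1] - q[1], p[2] - q[2])
D = [d(p, q) for p in pts, q in pts]

side = maximum(D[i, mod1(i + 1, 4)] for i in 1:4)  # sides appear: H1 birth
diag = max(D[1, 3], D[2, 4])                       # diagonals appear: H1 death
(side, diag)
```

So this diagram has a single H1 point at \((1, \sqrt{2})\).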

We sample 750 points from each wing silhouette using farthest-point sampling (which ensures good coverage of the shape), then compute 1-dimensional Rips persistence:

In [9]:
samples = Vector{Any}(undef, length(Xs))
Threads.@threads for i in eachindex(Xs)
    samples[i] = farthest_points_sample(Xs[i], 750)
end
In [10]:
pds_rips = @showprogress map(samples) do s
    rips_pd_1d(s, cutoff = 5, threshold = 200)
end;
In [11]:
wing_arrays = [convert(Array{Float64}, w) for w in wings]

72-element Vector{Matrix{Float64}} (matrix entries omitted)

3.2 Strategy 2: Directional height filtrations

Directional (height) filtrations

A height filtration sweeps a hyperplane across the shape in a chosen direction and tracks topology as the “visible” region grows. For a direction vector \(v\), we assign each foreground pixel the value \(\langle (i,j), v \rangle\) (its projection onto \(v\)), then compute sublevel-set persistence. Different directions capture different geometric aspects: a horizontal sweep detects how vein loops are arranged from base to tip, a vertical sweep captures dorsal-ventral structure, and diagonal sweeps capture oblique patterns. Using multiple directions enriches the topological signature.
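The pixel values a direction assigns are easy to sketch; the `height` helper below is hypothetical (not TDAfly's API) and just computes the projection \(\langle (i,j), v \rangle\) for each foreground pixel of a tiny binary mask:

```julia
# Height-filtration values for foreground pixels of a small binary mask.
mask = Bool[1 0 1; 0 1 0; 1 0 1]
v = [0.0, 1.0]                          # sweep left-to-right (by column)
height(px, v) = px[1] * v[1] + px[2] * v[2]

fg = findall(mask)                      # CartesianIndex of foreground pixels
vals = [height(Tuple(p), v) for p in fg]
```

With `v = [0.0, 1.0]` each pixel's value is simply its column index, so the sublevel sets grow column by column.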

We compute persistence along eight directions (every 22.5°) to capture finer angular structure of vein branching, including oblique vein angles missed by 4 directions. For each direction, we extract both H0 (connected component merging = vein branching) and H1 (loop formation):

In [12]:
angles = range(0, π, length=9)[1:8]
directions = [[sin(θ), cos(θ)] for θ in angles]
direction_names = ["Dir_$(round(Int, rad2deg(θ)))°" for θ in angles]

println("Using $(length(directions)) directions:")
for (name, dir) in zip(direction_names, directions)
    println("  $name: $dir")
end

# H1 persistence (loops) — as before, but expanded to 8 directions
pds_directional = Dict{String, Vector}()
for (dir, name) in zip(directions, direction_names)
    pds_directional[name] = @showprogress "$name H1" map(wing_arrays) do A
        directional_pd_1d(A, dir)
    end
end

# H0 persistence (connected component merging = vein branching patterns)
pds_directional_h0 = Dict{String, Vector}()
for (dir, name) in zip(directions, direction_names)
    pds_directional_h0[name] = @showprogress "$name H0" map(wing_arrays) do A
        directional_pd_0d(A, dir)
    end
end;
Using 8 directions:
  Dir_0°: [0.0, 1.0]
  Dir_22°: [0.3826834323650898, 0.9238795325112867]
  Dir_45°: [0.7071067811865475, 0.7071067811865476]
  Dir_68°: [0.9238795325112867, 0.38268343236508984]
  Dir_90°: [1.0, 6.123233995736766e-17]
  Dir_112°: [0.9238795325112867, -0.3826834323650897]
  Dir_135°: [0.7071067811865476, -0.7071067811865475]
  Dir_158°: [0.3826834323650899, -0.9238795325112867]

3.3 Strategy 3: Radial filtration

Radial filtration

The radial filtration assigns each foreground pixel a value equal to its distance from the centroid of the wing. Sublevel-set persistence on this function captures how topological features (loops in the venation) are distributed from the center of the wing outward. This is complementary to the directional filtrations.
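The filtration function itself can be sketched in plain Julia (assuming, as described above, distance to the foreground centroid; this is not the `radial_pd_*` implementation):

```julia
# Radial filtration values: distance of each foreground pixel
# to the centroid of the foreground pixel set.
using Statistics: mean

mask = Bool[0 1 0; 1 1 1; 0 1 0]        # a small plus-shaped mask
fg = Tuple.(findall(mask))
ci = mean(first.(fg))                   # centroid row
cj = mean(last.(fg))                    # centroid column
radii = [hypot(i - ci, j - cj) for (i, j) in fg]
```

For the plus shape the centroid is its centre pixel, so that pixel gets value 0 and the four arms get value 1: the sublevel sets grow outward in rings.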

In [13]:
pds_radial = @showprogress "radial_pd_1d" map(wing_arrays) do A
    radial_pd_1d(A)
end;

We also compute H0 persistence for the radial filtration, capturing how disconnected vein segments merge as the radial sweep grows outward:

In [14]:
pds_radial_h0 = @showprogress "radial_pd_0d" map(wing_arrays) do A
    radial_pd_0d(A)
end;

3.4 Strategy 4: Euclidean Distance Transform (EDT) filtration

EDT filtration

The Euclidean Distance Transform assigns each foreground pixel its distance to the nearest background pixel, so thick veins receive high EDT values. Using the negated EDT as the filtration function makes thick veins appear first in the sublevel-set filtration. This captures the vein thickness hierarchy, a diagnostic taxonomic character (e.g., Tabanidae have thickened costal and subcostal veins).
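A brute-force EDT illustrates the idea (the real pipeline presumably uses an efficient transform; this quadratic version is for exposition only):

```julia
# Brute-force EDT: each foreground pixel gets its distance to the nearest
# background pixel; negating it makes thick regions enter the filtration first.
function edt(mask::AbstractMatrix{Bool})
    bg = Tuple.(findall(!, mask))           # background pixel coordinates
    D = zeros(size(mask))
    for p in findall(mask)
        i, j = Tuple(p)
        D[i, j] = minimum(hypot(i - b[1], j - b[2]) for b in bg)
    end
    D
end

mask = trues(3, 3)
mask[1, 1] = false                          # a single background pixel
negedt = -edt(mask)                         # negated EDT filtration values
```

Here the pixel farthest from the background, `(3, 3)`, has the most negative filtration value and therefore appears first in the sublevel sets.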

In [15]:
pds_edt_h1 = @showprogress "EDT H1" map(wing_arrays) do A
    edt_pd_1d(A)
end

pds_edt_h0 = @showprogress "EDT H0" map(wing_arrays) do A
    edt_pd_0d(A)
end;

3.5 Strategy 5: Cubical (grayscale sublevel-set) persistence

Grayscale sublevel-set persistence

The function cubical_pd computes sublevel-set persistence directly on the grayscale wing image (inverted so that dark veins have low filtration values). This captures the intensity landscape of the wing image without any thresholding, preserving information about semi-transparent wing membrane regions and vein intensity gradients.

In [16]:
pds_cubical = @showprogress "Cubical" map(wing_arrays) do A
    cubical_pd(A; dim_max=1)
end

pds_cubical_h0 = [pd[1] for pd in pds_cubical]
pds_cubical_h1 = [pd[2] for pd in pds_cubical];

3.6 Persistence vectorization

Raw persistence diagrams live in a space that is not directly amenable to standard machine learning. We vectorize them using three approaches:

Persistence images

A persistence image is a stable, finite-dimensional representation of a persistence diagram. Each point \((b, d)\) is mapped to \((b, d - b)\) coordinates (birth vs persistence), weighted by a function that emphasizes long-lived features, then smoothed with a Gaussian kernel and discretized onto a grid. The result is a matrix (image) that can be treated as a feature vector. Persistence images are stable with respect to the 1-Wasserstein distance and have proven effective in machine learning pipelines.
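The construction can be sketched in a few lines (a plain-Python illustration with a hypothetical linear persistence weight and a fixed grid; the `PersistenceImage` constructor from PersistenceDiagrams.jl handles ranges, weights and smoothing for the real pipeline):

```python
import numpy as np

def persistence_image(diagram, grid=10, sigma=0.1, lims=(0.0, 1.0)):
    """Rasterize a diagram [(birth, death), ...] onto a grid x grid image.

    Each point moves to (birth, persistence) coordinates, gets a weight
    (here simply its persistence, so long bars count more), and is smeared
    with an isotropic Gaussian of width sigma.
    """
    xs = np.linspace(lims[0], lims[1], grid)     # birth axis
    ys = np.linspace(lims[0], lims[1], grid)     # persistence axis
    X, Y = np.meshgrid(xs, ys)
    img = np.zeros((grid, grid))
    for b, d in diagram:
        p = d - b                                # persistence of the feature
        img += p * np.exp(-((X - b) ** 2 + (Y - p) ** 2) / (2 * sigma ** 2))
    return img
```

A diagram with a single point at birth 0.2 and death 0.8 produces an image whose mass concentrates near birth 0.2, persistence 0.6; flattening the matrix with `vec` (as done below) gives the feature vector.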

Betti curves

The Betti curve \(\beta_k(t)\) counts the number of \(k\)-dimensional features alive at filtration value \(t\). For dimension 1, it counts the number of loops present at each scale. Discretized over a grid, it produces a feature vector. Betti curves are simple, interpretable, and capture the “topological complexity” of the shape at each scale.
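The definition translates almost verbatim into code — a minimal sketch (plain Python; the `BettiCurve` constructor used below chooses the sampling grid from the diagrams automatically):

```python
def betti_curve(diagram, grid):
    """Sample beta(t) = #{(b, d) in diagram : b <= t < d} on a grid."""
    return [sum(1 for b, d in diagram if b <= t < d) for t in grid]
```

For the two intervals (0, 2) and (1, 3) sampled at t = 0, 1, 2, 3 this gives [1, 2, 1, 0]: one feature alive at t = 0, both at t = 1, and so on.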

Persistence landscapes

A persistence landscape is a sequence of piecewise-linear functions derived from a persistence diagram. The \(k\)-th landscape \(\lambda_k\) is the \(k\)-th largest value of a collection of tent functions, one per interval. Landscapes live in a Banach space, which means we can compute means, perform hypothesis tests, and use them directly in statistical and machine learning methods. They provide a richer representation than Betti curves.
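A minimal sketch of the tent-function construction (plain Python; the `Landscape(k, ...)` vectorizer below implements the same idea, with `k` starting at 1):

```python
def landscape(diagram, k, grid):
    """Sample the k-th persistence landscape lambda_k on a grid.

    The tent for an interval (b, d) rises linearly from 0 at b to
    (d - b) / 2 at the midpoint and back to 0 at d; lambda_k(t) is the
    k-th largest tent value at t (0 if fewer than k tents are positive).
    """
    def tent(b, d, t):
        return max(0.0, min(t - b, d - t))
    values = []
    for t in grid:
        tents = sorted((tent(b, d, t) for b, d in diagram), reverse=True)
        values.append(tents[k - 1] if k <= len(tents) else 0.0)
    return values
```

For the intervals (0, 2) and (1, 3), the first landscape traces the upper envelope of the two tents while the second is nonzero only where the tents overlap — this is the extra information landscapes carry beyond a single curve.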

In [17]:
# Vectorize Rips persistence
PI_rips = PersistenceImage(pds_rips, size = (15, 15))
pi_rips = PI_rips.(pds_rips)

bc_rips = BettiCurve(pds_rips; length = 50)
betti_rips = bc_rips.(pds_rips)

land1_rips = Landscape(1, pds_rips; length = 50)
land2_rips = Landscape(2, pds_rips; length = 50)
land1_rips_vecs = land1_rips.(pds_rips)
land2_rips_vecs = land2_rips.(pds_rips);
In [18]:
# Vectorize directional persistence (H1)
pi_directional = Dict{String, Vector}()
betti_directional = Dict{String, Vector}()

for name in direction_names
    pds = pds_directional[name]
    PI_d = PersistenceImage(pds, size = (10, 10))
    pi_directional[name] = PI_d.(pds)

    bc_d = BettiCurve(pds; length = 30)
    betti_directional[name] = bc_d.(pds)
end

# Vectorize directional persistence (H0 — NEW)
pi_directional_h0 = Dict{String, Vector}()
betti_directional_h0 = Dict{String, Vector}()

for name in direction_names
    pds = pds_directional_h0[name]
    # Filter out infinite intervals for vectorization
    pds_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds]
    if any(!isempty, pds_finite)
        PI_d0 = PersistenceImage(pds_finite, size = (10, 10))
        pi_directional_h0[name] = PI_d0.(pds_finite)
        bc_d0 = BettiCurve(pds_finite; length = 30)
        betti_directional_h0[name] = bc_d0.(pds_finite)
    else
        pi_directional_h0[name] = [zeros(10 * 10) for _ in pds]
        betti_directional_h0[name] = [zeros(30) for _ in pds]
    end
end

# Radial H1
PI_rad = PersistenceImage(pds_radial, size = (10, 10))
pi_radial = PI_rad.(pds_radial)

bc_rad = BettiCurve(pds_radial; length = 30)
betti_radial = bc_rad.(pds_radial)

# Radial H0 (NEW)
pds_radial_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_radial_h0]
if any(!isempty, pds_radial_h0_finite)
    PI_rad_h0 = PersistenceImage(pds_radial_h0_finite, size = (10, 10))
    pi_radial_h0 = PI_rad_h0.(pds_radial_h0_finite)
    bc_rad_h0 = BettiCurve(pds_radial_h0_finite; length = 30)
    betti_radial_h0 = bc_rad_h0.(pds_radial_h0_finite)
else
    pi_radial_h0 = [zeros(10 * 10) for _ in pds_radial_h0]
    betti_radial_h0 = [zeros(30) for _ in pds_radial_h0]
end

# EDT H1 (NEW)
if any(!isempty, pds_edt_h1)
    PI_edt_h1 = PersistenceImage(pds_edt_h1, size = (10, 10))
    pi_edt_h1 = PI_edt_h1.(pds_edt_h1)
    bc_edt_h1 = BettiCurve(pds_edt_h1; length = 30)
    betti_edt_h1 = bc_edt_h1.(pds_edt_h1)
else
    pi_edt_h1 = [zeros(10 * 10) for _ in pds_edt_h1]
    betti_edt_h1 = [zeros(30) for _ in pds_edt_h1]
end

# EDT H0 (NEW)
pds_edt_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_edt_h0]
if any(!isempty, pds_edt_h0_finite)
    PI_edt_h0 = PersistenceImage(pds_edt_h0_finite, size = (10, 10))
    pi_edt_h0 = PI_edt_h0.(pds_edt_h0_finite)
    bc_edt_h0 = BettiCurve(pds_edt_h0_finite; length = 30)
    betti_edt_h0 = bc_edt_h0.(pds_edt_h0_finite)
else
    pi_edt_h0 = [zeros(10 * 10) for _ in pds_edt_h0]
    betti_edt_h0 = [zeros(30) for _ in pds_edt_h0]
end

# Cubical H0 and H1 (NEW)
if any(!isempty, pds_cubical_h1)
    PI_cub_h1 = PersistenceImage(pds_cubical_h1, size = (10, 10))
    pi_cubical_h1 = PI_cub_h1.(pds_cubical_h1)
    bc_cub_h1 = BettiCurve(pds_cubical_h1; length = 30)
    betti_cubical_h1 = bc_cub_h1.(pds_cubical_h1)
else
    pi_cubical_h1 = [zeros(10 * 10) for _ in pds_cubical_h1]
    betti_cubical_h1 = [zeros(30) for _ in pds_cubical_h1]
end

pds_cubical_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_cubical_h0]
if any(!isempty, pds_cubical_h0_finite)
    PI_cub_h0 = PersistenceImage(pds_cubical_h0_finite, size = (10, 10))
    pi_cubical_h0 = PI_cub_h0.(pds_cubical_h0_finite)
    bc_cub_h0 = BettiCurve(pds_cubical_h0_finite; length = 30)
    betti_cubical_h0 = bc_cub_h0.(pds_cubical_h0_finite)
else
    pi_cubical_h0 = [zeros(10 * 10) for _ in pds_cubical_h0]
    betti_cubical_h0 = [zeros(30) for _ in pds_cubical_h0]
end

println("Vectorization complete.")
println("  Directional: $(length(direction_names)) directions × H0 + H1")
println("  Radial: H0 + H1")
println("  EDT: H0 + H1")
println("  Cubical: H0 + H1");
Vectorization complete.
  Directional: 8 directions × H0 + H1
  Radial: H0 + H1
  EDT: H0 + H1
  Cubical: H0 + H1

3.7 Examples

Below, for one specimen per family, we plot the sorted persistence values of the 1-dimensional features from each filtration strategy, alongside the sampled point cloud and one H₀ example:

In [19]:
example_indices = [findfirst(==(f), families) for f in sort(unique(families))]

for i in example_indices
    pers_rips = persistence.(pds_rips[i])
    pers_dir = persistence.(pds_directional[direction_names[1]][i])
    pers_rad = persistence.(pds_radial[i])
    pers_edt = persistence.(pds_edt_h1[i])

    p1 = isempty(pers_rips) ? plot(title = "Rips H₁ (empty)") :
         bar(sort(pers_rips, rev = true), title = "Rips H₁", legend = false, ylabel = "persistence")
    p2 = isempty(pers_dir) ? plot(title = "Dir 0° H₁ (empty)") :
         bar(sort(pers_dir, rev = true), title = "Dir 0° H₁", legend = false, ylabel = "persistence")
    p3 = isempty(pers_rad) ? plot(title = "Radial H₁ (empty)") :
         bar(sort(pers_rad, rev = true), title = "Radial H₁", legend = false, ylabel = "persistence")
    p4 = isempty(pers_edt) ? plot(title = "EDT H₁ (empty)") :
         bar(sort(pers_edt, rev = true), title = "EDT H₁", legend = false, ylabel = "persistence")
    p5 = scatter(last.(samples[i]), first.(samples[i]),
                 aspect_ratio = :equal, markersize = 1, legend = false, title = "Point cloud")

    # H0 examples
    pers_dir_h0 = [persistence(x) for x in pds_directional_h0[direction_names[1]][i] if isfinite(persistence(x))]
    p6 = isempty(pers_dir_h0) ? plot(title = "Dir 0° H₀ (empty)") :
         bar(sort(pers_dir_h0, rev = true), title = "Dir 0° H₀", legend = false, ylabel = "persistence")

    p = plot(p1, p2, p3, p4, p5, p6, layout = (3, 2), size = (900, 900),
             plot_title = "$(families[i]) ($(individuals[i]))")
    display(p)
end;

3.8 Summary statistics

We extract summary statistics from each persistence diagram:

In [20]:
stat_names = ["count", "max_pers", "total_pers", "total_pers2",
              "q10", "q25", "median", "q75", "q90", "entropy", "std_pers"]

stats_rips = collect(hcat([pd_statistics(pd) for pd in pds_rips]...)')

# Directional H1 stats
stats_directional = Dict{String, Matrix}()
for name in direction_names
    stats_directional[name] = collect(hcat([pd_statistics(pd) for pd in pds_directional[name]]...)')
end

# Directional H0 stats (NEW)
stats_directional_h0 = Dict{String, Matrix}()
for name in direction_names
    stats_directional_h0[name] = collect(hcat([pd_statistics(pd) for pd in pds_directional_h0[name]]...)')
end

stats_radial = collect(hcat([pd_statistics(pd) for pd in pds_radial]...)')
stats_radial_h0 = collect(hcat([pd_statistics(pd) for pd in pds_radial_h0]...)')   # NEW

# EDT stats (NEW)
stats_edt_h1 = collect(hcat([pd_statistics(pd) for pd in pds_edt_h1]...)')
stats_edt_h0 = collect(hcat([pd_statistics(pd) for pd in pds_edt_h0]...)')

# Cubical stats (NEW)
stats_cubical_h1 = collect(hcat([pd_statistics(pd) for pd in pds_cubical_h1]...)')
stats_cubical_h0 = collect(hcat([pd_statistics(pd) for pd in pds_cubical_h0]...)')

# Combined stats from ALL filtrations (H1 only, for backward compatibility)
stats_all_h1 = hcat(stats_rips, stats_radial,
                    [stats_directional[name] for name in direction_names]...,
                    stats_edt_h1, stats_cubical_h1)

# Combined stats from ALL filtrations (H0 + H1 = comprehensive)
stats_all = hcat(
    stats_rips,
    stats_radial, stats_radial_h0,
    [stats_directional[name] for name in direction_names]...,
    [stats_directional_h0[name] for name in direction_names]...,
    stats_edt_h1, stats_edt_h0,
    stats_cubical_h1, stats_cubical_h0,
)

println("Statistics dimensions:")
println("  Rips H1: $(size(stats_rips))")
println("  All H1 only: $(size(stats_all_h1))")
println("  All H0+H1 (comprehensive): $(size(stats_all))")
Statistics dimensions:
  Rips H1: (72, 11)
  All H1 only: (72, 132)
  All H0+H1 (comprehensive): (72, 253)

3.8.1 Statistics comparison by family

In [21]:
stats_df = DataFrame(
    sample = individuals,
    family = families,
    n_intervals_rips = stats_rips[:, 1],
    max_pers_rips = stats_rips[:, 2],
    entropy_rips = stats_rips[:, 10],
    n_intervals_rad = stats_radial[:, 1],
    max_pers_rad = stats_radial[:, 2],
    entropy_rad = stats_radial[:, 10],
    n_intervals_edt_h1 = stats_edt_h1[:, 1],
    max_pers_edt_h1 = stats_edt_h1[:, 2],
    n_intervals_cub_h1 = stats_cubical_h1[:, 1],
    max_pers_cub_h1 = stats_cubical_h1[:, 2],
)

p1 = boxplot(stats_df.family, stats_df.n_intervals_rips,
             title = "Rips: # intervals", legend = false, ylabel = "count", xrotation = 45)
p2 = boxplot(stats_df.family, stats_df.max_pers_rips,
             title = "Rips: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
p3 = boxplot(stats_df.family, stats_df.n_intervals_edt_h1,
             title = "EDT H₁: # intervals", legend = false, ylabel = "count", xrotation = 45)
p4 = boxplot(stats_df.family, stats_df.max_pers_edt_h1,
             title = "EDT H₁: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
plot(p1, p2, p3, p4, layout = (2, 2), size = (1000, 700))

3.9 Decision tree on rich 1D persistence statistics

We now use the full 1-dimensional persistence statistics from all filtrations (stats_all) as features for a single decision tree classifier.

In [22]:
using DecisionTree
using Random: MersenneTwister

labels_tree = families
X_tree = sanitize_feature_matrix(stats_all)

tree_blocks_h1 = ["Rips_H1", "Radial_H1", [name * "_H1" for name in direction_names]..., "EDT_H1", "Cubical_H1"]
tree_blocks_h0 = ["Radial_H0", [name * "_H0" for name in direction_names]..., "EDT_H0", "Cubical_H0"]
tree_blocks = [tree_blocks_h1; tree_blocks_h0]
tree_feature_names = ["$(block)__$(stat)" for block in tree_blocks for stat in stat_names]

function loocv_decision_tree(X::Matrix, y::Vector{String};
                             max_depth::Int = 6,
                             min_samples_leaf::Int = 2,
                             min_samples_split::Int = 2,
                             rng_seed::Int = 20260223)
    Xclean = sanitize_feature_matrix(X)
    n = size(Xclean, 1)
    predictions = Vector{String}(undef, n)

    for i in 1:n
        train_idx = setdiff(1:n, i)
        X_train = Xclean[train_idx, :]
        y_train = y[train_idx]

        tree = DecisionTree.build_tree(
            y_train,
            X_train,
            size(X_train, 2),
            max_depth,
            min_samples_leaf,
            min_samples_split,
            0.0;
            loss = DecisionTree.util.gini,
            rng = MersenneTwister(rng_seed + i),
            impurity_importance = true
        )

        predictions[i] = DecisionTree.apply_tree(tree, Xclean[i, :])
    end

    (accuracy = mean(predictions .== y), predictions = predictions)
end

tree_results = DataFrame(
    max_depth = Int[],
    min_samples_leaf = Int[],
    min_samples_split = Int[],
    n_correct = Int[],
    accuracy = Float64[],
    balanced_accuracy = Float64[],
    macro_f1 = Float64[],
)

for max_depth in [3, 4, 5, 6, 8]
    for min_leaf in [1, 2, 3]
        for min_split in [2, 4]
            r = loocv_decision_tree(X_tree, labels_tree;
                                    max_depth = max_depth,
                                    min_samples_leaf = min_leaf,
                                    min_samples_split = min_split)
            m = classification_metrics(labels_tree, r.predictions)
            push!(tree_results, (
                max_depth,
                min_leaf,
                min_split,
                sum(r.predictions .== labels_tree),
                r.accuracy,
                m.balanced_accuracy,
                m.macro_f1
            ))
        end
    end
end

sort!(tree_results, :accuracy, rev = true)
first(tree_results, 10)
10×7 DataFrame
Row max_depth min_samples_leaf min_samples_split n_correct accuracy balanced_accuracy macro_f1
Int64 Int64 Int64 Int64 Float64 Float64 Float64
1 6 3 2 38 0.527778 0.459957 0.455618
2 6 3 4 38 0.527778 0.459957 0.455618
3 8 3 2 38 0.527778 0.459957 0.455618
4 8 3 4 38 0.527778 0.459957 0.455618
5 4 3 2 37 0.513889 0.443885 0.421777
6 4 3 4 37 0.513889 0.443885 0.421777
7 5 1 2 37 0.513889 0.451623 0.435796
8 5 3 2 37 0.513889 0.434957 0.424286
9 5 3 4 37 0.513889 0.434957 0.424286
10 6 1 4 37 0.513889 0.451623 0.445173
In [23]:
best_tree = tree_results[1, :]

tree_model = DecisionTree.build_tree(
    labels_tree,
    X_tree,
    size(X_tree, 2),
    best_tree.max_depth,
    best_tree.min_samples_leaf,
    best_tree.min_samples_split,
    0.0;
    loss = DecisionTree.util.gini,
    rng = MersenneTwister(20260223),
    impurity_importance = true
)

tree_importance = DecisionTree.impurity_importance(tree_model; normalize = true)

tree_importance_df = DataFrame(
    feature = tree_feature_names,
    importance = tree_importance
)
sort!(tree_importance_df, :importance, rev = true)

println("Best Decision Tree LOOCV: $(best_tree.n_correct)/$(length(labels_tree)) ($(round(best_tree.accuracy * 100, digits = 1))%)")
println("Balanced accuracy: $(round(best_tree.balanced_accuracy * 100, digits = 1))%")
println("Macro-F1: $(round(best_tree.macro_f1 * 100, digits = 1))%")

first(filter(:importance => >(0.0), tree_importance_df), 15)
Best Decision Tree LOOCV: 38/72 (52.8%)
Balanced accuracy: 46.0%
Macro-F1: 45.6%
12×2 DataFrame
Row feature importance
String Float64
1 Dir_0°_H1__q75 0.141063
2 Dir_90°_H0__count 0.126913
3 Dir_90°_H0__q10 0.121096
4 Cubical_H0__entropy 0.119667
5 Cubical_H0__count 0.108628
6 Dir_0°_H1__count 0.108293
7 Dir_22°_H1__max_pers 0.106663
8 Rips_H1__q10 0.104676
9 Dir_90°_H0__max_pers 0.0242658
10 Dir_112°_H1__q90 0.0231103
11 Dir_22°_H0__count 0.00840373
12 Dir_112°_H0__q90 0.00722195
In [24]:
topk = min(12, nrow(tree_importance_df))
top_tree_importance = first(tree_importance_df, topk)

bar(
    top_tree_importance.feature,
    top_tree_importance.importance,
    xlabel = "1D persistence statistics",
    ylabel = "Normalized impurity importance",
    title = "Decision tree feature importance (top $(topk))",
    legend = false,
    xrotation = 45,
    size = (1100, 550),
)

4 Distance matrices

Distances between persistence diagrams

The Wasserstein distance \(W_q\) between two persistence diagrams is the cost of the optimal matching between their points (including matching points to the diagonal, which represents trivial features). With \(q=1\) it equals the Earth Mover’s Distance; with \(q=2\) it penalizes large mismatches more. The Bottleneck distance \(d_B\) is the \(\ell^\infty\) version: it measures the worst single mismatch in the optimal pairing. These distances are metrics on the space of persistence diagrams and are stable with respect to perturbations of the input data.
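To make the optimal matching concrete, here is a brute-force plain-Python sketch of the 1-Wasserstein distance with the \(\ell^\infty\) ground metric (a common convention, though conventions vary between libraries; `wasserstein_distance_matrix` below uses the PersistenceDiagrams.jl implementation). It pads each diagram with diagonal projections of the other's points, so unmatched features pay their cost of moving to the diagonal, then minimizes over all pairings — exponential in diagram size, so toy use only:

```python
import itertools

def wasserstein1(dgm1, dgm2):
    """Brute-force W_1 between two small diagrams of (birth, death) pairs."""
    def diag(p):                       # nearest point on the diagonal
        m = (p[0] + p[1]) / 2
        return (m, m)
    def cost(p, q):                    # l-infinity ground metric;
        if p[0] == p[1] and q[0] == q[1]:
            return 0.0                 # diagonal-to-diagonal moves are free
        return max(abs(p[0] - q[0]), abs(p[1] - q[1]))
    # pad so both sides have equal size and every point has a partner
    a = list(dgm1) + [diag(q) for q in dgm2]
    b = list(dgm2) + [diag(p) for p in dgm1]
    return min(sum(cost(a[i], b[j]) for i, j in enumerate(perm))
               for perm in itertools.permutations(range(len(b))))
```

Replacing the sum by a max over matched costs gives the bottleneck distance. Note that a lone interval (0, 2) matched against an empty diagram costs 1.0 — its distance to the diagonal.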

We compute multiple distance metrics between the persistence diagrams from each filtration:

In [25]:
labels = families

# Rips-based distances (Rips PDs have ~20 intervals, so Wasserstein/Bottleneck are feasible)
D_pi_rips = pairwise_distance([vec(v) for v in pi_rips])
D_betti_rips = pairwise_distance(betti_rips, euclidean)
D_land_rips = pairwise_distance(land1_rips_vecs, euclidean)
D_wass1_rips = wasserstein_distance_matrix(pds_rips, q = 1)
D_wass2_rips = wasserstein_distance_matrix(pds_rips, q = 2)   # Phase 5: Wasserstein-2
D_bott_rips = bottleneck_distance_matrix(pds_rips)

# Directional/radial PDs have hundreds of intervals, so
# Wasserstein/Bottleneck would be prohibitively slow on 72×72 pairs.
# We use only vectorized distances (persistence images, Betti curves, landscapes),
# which are Euclidean distances on fixed-size feature vectors and compute instantly.

# Directional H1 distances (combine all directions via sum of per-direction distances)
D_pi_dir = sum(pairwise_distance([vec(v) for v in pi_directional[name]]) for name in direction_names)
D_betti_dir = sum(pairwise_distance(betti_directional[name], euclidean) for name in direction_names)

# Directional H0 distances (NEW)
D_pi_dir_h0 = sum(pairwise_distance([vec(v) for v in pi_directional_h0[name]]) for name in direction_names)
D_betti_dir_h0 = sum(pairwise_distance(betti_directional_h0[name], euclidean) for name in direction_names)

# Radial distances
D_pi_rad = pairwise_distance([vec(v) for v in pi_radial])
D_betti_rad = pairwise_distance(betti_radial, euclidean)

# Radial H0 distances (NEW)
D_pi_rad_h0 = pairwise_distance([vec(v) for v in pi_radial_h0])
D_betti_rad_h0 = pairwise_distance(betti_radial_h0, euclidean)

# EDT distances (NEW)
D_pi_edt_h1 = pairwise_distance([vec(v) for v in pi_edt_h1])
D_betti_edt_h1 = pairwise_distance(betti_edt_h1, euclidean)
D_pi_edt_h0 = pairwise_distance([vec(v) for v in pi_edt_h0])
D_betti_edt_h0 = pairwise_distance(betti_edt_h0, euclidean)

# Cubical distances (NEW)
D_pi_cub_h1 = pairwise_distance([vec(v) for v in pi_cubical_h1])
D_betti_cub_h1 = pairwise_distance(betti_cubical_h1, euclidean)
D_pi_cub_h0 = pairwise_distance([vec(v) for v in pi_cubical_h0])
D_betti_cub_h0 = pairwise_distance(betti_cubical_h0, euclidean)

distances = Dict(
    # Rips
    "Rips PI" => D_pi_rips,
    "Rips Bottleneck" => D_bott_rips,
    "Rips Wass-1" => D_wass1_rips,
    "Rips Wass-2" => D_wass2_rips,
    "Rips Betti" => D_betti_rips,
    "Rips Landscape" => D_land_rips,
    # Directional H1
    "Directional H1 PI" => D_pi_dir,
    "Directional H1 Betti" => D_betti_dir,
    # Directional H0 (NEW)
    "Directional H0 PI" => D_pi_dir_h0,
    "Directional H0 Betti" => D_betti_dir_h0,
    # Radial
    "Radial H1 PI" => D_pi_rad,
    "Radial H1 Betti" => D_betti_rad,
    "Radial H0 PI" => D_pi_rad_h0,
    "Radial H0 Betti" => D_betti_rad_h0,
    # EDT (NEW)
    "EDT H1 PI" => D_pi_edt_h1,
    "EDT H1 Betti" => D_betti_edt_h1,
    "EDT H0 PI" => D_pi_edt_h0,
    "EDT H0 Betti" => D_betti_edt_h0,
    # Cubical (NEW)
    "Cubical H1 PI" => D_pi_cub_h1,
    "Cubical H1 Betti" => D_betti_cub_h1,
    "Cubical H0 PI" => D_pi_cub_h0,
    "Cubical H0 Betti" => D_betti_cub_h0,
);
In [26]:
p1 = plot_heatmap(D_wass1_rips, individuals, "Rips Wasserstein-1")
p2 = plot_heatmap(D_pi_dir_h0, individuals, "Directional H0 PI")
p3 = plot_heatmap(D_pi_edt_h1, individuals, "EDT H1 PI")
p4 = plot_heatmap(D_pi_cub_h1, individuals, "Cubical H1 PI")
plot(p1, p2, p3, p4, layout = (2, 2), size = (1000, 900))

5 Classification

Leave-one-out cross-validation (LOOCV)

With only 72 samples, we use leave-one-out cross-validation: for each sample, the classifier is trained on all other samples and tested on the held-out one. The accuracy is the fraction of correctly predicted labels across all 72 folds. LOOCV has low bias (nearly the entire dataset is used for training) and is the standard validation strategy for small datasets.
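The procedure is short enough to state in full — a plain-Python sketch with a 1-NN classifier on a precomputed distance matrix (the notebook's `loocv_knn` and related helpers generalize this pattern):

```python
def loocv_accuracy(D, labels):
    """Leave-one-out accuracy of 1-NN on a precomputed distance matrix D."""
    n = len(labels)
    correct = 0
    for i in range(n):
        # train on everything except i: pick the nearest other sample
        j = min((j for j in range(n) if j != i), key=lambda j: D[i][j])
        correct += labels[j] == labels[i]
    return correct / n
```

Because the held-out sample never influences its own prediction, the resulting accuracy is an (almost) unbiased estimate of generalization — as long as no hyperparameter was tuned on the same folds (the nested LOOCV mentioned in the abstract addresses exactly that).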

5.1 Distance-based classifiers: k-NN

k-Nearest Neighbors (k-NN)

Given a precomputed distance matrix, k-NN classifies a query point by majority vote among its \(k\) nearest neighbors. Weighted k-NN weights each neighbor’s vote by \(1/d\) (inverse distance), giving closer neighbors more influence. The nearest centroid classifier assigns the query to the class whose average distance to the query is smallest. These are nonparametric methods that work directly with any distance or dissimilarity measure — making them natural for TDA, where we have principled distances between persistence diagrams.
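The weighted vote can be sketched in a few lines (plain Python; `loocv_knn_weighted` below wraps this idea in the LOOCV loop, and the small `eps` guarding against zero distances is our addition):

```python
from collections import defaultdict

def weighted_knn_predict(D, labels, i, k=3, eps=1e-12):
    """Classify sample i by inverse-distance-weighted vote among its
    k nearest other samples, using a precomputed distance matrix D."""
    neighbors = sorted((j for j in range(len(labels)) if j != i),
                       key=lambda j: D[i][j])[:k]
    votes = defaultdict(float)
    for j in neighbors:
        votes[labels[j]] += 1.0 / (D[i][j] + eps)   # closer => louder vote
    return max(votes, key=votes.get)
```

With inverse-distance weights, one very close neighbor can outvote two distant ones of another class, which softens the fixed choice of k.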

In [27]:
knn_results = []
for (dist_name, D) in distances
    for k in [1, 3, 5]
        r = loocv_knn(D, labels; k = k)
        push!(knn_results, (
            method = "k-NN (k=$k)",
            distance = dist_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy
        ))

        r2 = loocv_knn_weighted(D, labels; k = k)
        push!(knn_results, (
            method = "W-kNN (k=$k)",
            distance = dist_name,
            n_correct = sum(r2.predictions .== labels),
            n_total = length(labels),
            accuracy = r2.accuracy
        ))
    end

    r3 = loocv_nearest_centroid(D, labels)
    push!(knn_results, (
        method = "Nearest centroid",
        distance = dist_name,
        n_correct = sum(r3.predictions .== labels),
        n_total = length(labels),
        accuracy = r3.accuracy
    ))
end

knn_df = DataFrame(knn_results)
sort!(knn_df, :accuracy, rev = true)
first(knn_df, 20)
20×5 DataFrame
Row method distance n_correct n_total accuracy
String String Int64 Int64 Float64
1 W-kNN (k=3) Cubical H1 Betti 49 72 0.680556
2 k-NN (k=3) Cubical H1 Betti 48 72 0.666667
3 k-NN (k=1) Rips Betti 48 72 0.666667
4 W-kNN (k=1) Rips Betti 48 72 0.666667
5 k-NN (k=1) Cubical H1 Betti 47 72 0.652778
6 W-kNN (k=1) Cubical H1 Betti 47 72 0.652778
7 W-kNN (k=5) Cubical H1 Betti 47 72 0.652778
8 Nearest centroid Cubical H1 Betti 47 72 0.652778
9 k-NN (k=1) Cubical H1 PI 45 72 0.625
10 W-kNN (k=1) Cubical H1 PI 45 72 0.625
11 k-NN (k=5) Cubical H1 Betti 45 72 0.625
12 k-NN (k=3) Rips Wass-1 45 72 0.625
13 W-kNN (k=3) Rips Wass-1 45 72 0.625
14 k-NN (k=5) Cubical H1 PI 44 72 0.611111
15 k-NN (k=5) Rips Wass-1 44 72 0.611111
16 W-kNN (k=5) Rips Wass-1 44 72 0.611111
17 W-kNN (k=5) Cubical H1 PI 43 72 0.597222
18 k-NN (k=3) Rips Wass-2 42 72 0.583333
19 W-kNN (k=3) Rips Wass-2 42 72 0.583333
20 k-NN (k=3) Rips PI 41 72 0.569444

5.2 Feature-based classifiers

We construct feature matrices by concatenating the vectorized TDA representations from all filtrations:

In [28]:
# Feature matrices at different levels of richness
X_stats_rips = sanitize_feature_matrix(stats_rips)
X_stats_all_h1 = sanitize_feature_matrix(stats_all_h1)
X_stats_all = sanitize_feature_matrix(stats_all)

X_rips_full = build_feature_matrix(
    stats = stats_rips,
    pi = pi_rips,
    betti = betti_rips,
    landscape = land1_rips_vecs,
) |> sanitize_feature_matrix

# Multi-filtration features (H1 only): combine everything
all_pi_h1 = [vcat(vec(pi_rips[i]),
                  vec(pi_radial[i]),
                  [vec(pi_directional[name][i]) for name in direction_names]...,
                  vec(pi_edt_h1[i]),
                  vec(pi_cubical_h1[i]))
             for i in 1:length(families)]

all_betti_h1 = [vcat(betti_rips[i],
                     betti_radial[i],
                     [betti_directional[name][i] for name in direction_names]...,
                     betti_edt_h1[i],
                     betti_cubical_h1[i])
                for i in 1:length(families)]

X_multi_h1 = build_feature_matrix(
    stats = stats_all_h1,
    pi = all_pi_h1,
    betti = all_betti_h1,
) |> sanitize_feature_matrix

# Multi-filtration features (H0 + H1 = comprehensive)
all_pi = [vcat(vec(pi_rips[i]),
               vec(pi_radial[i]), vec(pi_radial_h0[i]),
               [vec(pi_directional[name][i]) for name in direction_names]...,
               [vec(pi_directional_h0[name][i]) for name in direction_names]...,
               vec(pi_edt_h1[i]), vec(pi_edt_h0[i]),
               vec(pi_cubical_h1[i]), vec(pi_cubical_h0[i]))
          for i in 1:length(families)]

all_betti = [vcat(betti_rips[i],
                  betti_radial[i], betti_radial_h0[i],
                  [betti_directional[name][i] for name in direction_names]...,
                  [betti_directional_h0[name][i] for name in direction_names]...,
                  betti_edt_h1[i], betti_edt_h0[i],
                  betti_cubical_h1[i], betti_cubical_h0[i])
             for i in 1:length(families)]

X_multi = build_feature_matrix(
    stats = stats_all,
    pi = all_pi,
    betti = all_betti,
) |> sanitize_feature_matrix

println("Feature dimensions:")
println("  Rips stats only: $(size(X_stats_rips))")
println("  All H1 stats: $(size(X_stats_all_h1))")
println("  All H0+H1 stats: $(size(X_stats_all))")
println("  Rips full (stats+PI+Betti+Land): $(size(X_rips_full))")
println("  Multi-filtration H1 only: $(size(X_multi_h1))")
println("  Multi-filtration H0+H1 (comprehensive): $(size(X_multi))")
Feature dimensions:
  Rips stats only: (72, 11)
  All H1 stats: (72, 132)
  All H0+H1 stats: (72, 253)
  Rips full (stats+PI+Betti+Land): (72, 336)
  Multi-filtration H1 only: (72, 1837)
  Multi-filtration H0+H1 (comprehensive): (72, 3388)

5.2.1 SVM (Support Vector Machine)

Support Vector Machine (SVM)

An SVM finds the hyperplane that maximizes the margin between classes. The RBF (Radial Basis Function) kernel maps data into a high-dimensional space where linear separation becomes possible, controlled by a cost parameter \(C\) (penalty for misclassification). For distance matrices, we convert distances to an RBF-like kernel \(K(i,j) = \exp(-D_{ij}^2 / 2\sigma^2)\) and train a linear SVM on the resulting kernel matrix. This is sometimes called an “empirical kernel map.”
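The distance-to-kernel conversion is a one-liner plus a bandwidth choice — a plain-Python sketch, where the median-of-off-diagonal-distances default for \(\sigma\) is a common heuristic and an assumption here, not necessarily what `loocv_svm_distance` uses internally:

```python
import numpy as np

def distance_to_kernel(D, sigma=None):
    """K[i, j] = exp(-D[i, j]^2 / (2 sigma^2)) from a distance matrix D.

    sigma defaults to the median off-diagonal distance (a common
    bandwidth heuristic).
    """
    D = np.asarray(D, dtype=float)
    if sigma is None:
        sigma = np.median(D[~np.eye(len(D), dtype=bool)])
    return np.exp(-D ** 2 / (2 * sigma ** 2))
```

Each row of K can then be fed to a linear SVM as the sample's feature vector — the "empirical kernel map" — so any of the diagram distances above becomes usable by a kernel classifier.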

In [29]:
feature_sets = [
    ("Rips stats", X_stats_rips),
    ("All H1 stats", X_stats_all_h1),
    ("All H0+H1 stats", X_stats_all),
    ("Rips full", X_rips_full),
    ("Multi-filtration H1", X_multi_h1),
    ("Multi-filtration H0+H1", X_multi),
]

svm_results = []
for (feat_name, X) in feature_sets
    for kernel in [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear]
        for cost in [0.1, 1.0, 10.0, 100.0]
            kernel_name = kernel == LIBSVM.Kernel.RadialBasis ? "RBF" : "Linear"
            r = loocv_svm(X, labels; kernel = kernel, cost = cost)
            push!(svm_results, (
                method = "SVM ($kernel_name, C=$cost)",
                features = feat_name,
                n_correct = sum(r.predictions .== labels),
                n_total = length(labels),
                accuracy = r.accuracy
            ))
        end
    end
end

svm_df = DataFrame(svm_results)
sort!(svm_df, :accuracy, rev = true)
first(svm_df, 15)
15×5 DataFrame
Row method features n_correct n_total accuracy
String String Int64 Int64 Float64
1 SVM (Linear, C=0.1) Multi-filtration H1 60 72 0.833333
2 SVM (Linear, C=1.0) Multi-filtration H1 60 72 0.833333
3 SVM (Linear, C=10.0) Multi-filtration H1 60 72 0.833333
4 SVM (Linear, C=100.0) Multi-filtration H1 60 72 0.833333
5 SVM (Linear, C=0.1) All H0+H1 stats 56 72 0.777778
6 SVM (Linear, C=1.0) All H0+H1 stats 56 72 0.777778
7 SVM (Linear, C=10.0) All H0+H1 stats 56 72 0.777778
8 SVM (Linear, C=100.0) All H0+H1 stats 56 72 0.777778
9 SVM (Linear, C=0.1) Multi-filtration H0+H1 56 72 0.777778
10 SVM (Linear, C=1.0) Multi-filtration H0+H1 56 72 0.777778
11 SVM (Linear, C=10.0) Multi-filtration H0+H1 56 72 0.777778
12 SVM (Linear, C=100.0) Multi-filtration H0+H1 56 72 0.777778
13 SVM (RBF, C=10.0) All H0+H1 stats 53 72 0.736111
14 SVM (RBF, C=100.0) All H0+H1 stats 53 72 0.736111
15 SVM (Linear, C=0.1) All H1 stats 51 72 0.708333

5.2.2 SVM on distance matrices

In [30]:
svm_dist_results = []
for (dist_name, D) in distances
    for cost in [0.1, 1.0, 10.0, 100.0]
        r = loocv_svm_distance(D, labels; cost = cost)
        push!(svm_dist_results, (
            method = "SVM-dist (C=$cost)",
            distance = dist_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy
        ))
    end
end

svm_dist_df = DataFrame(svm_dist_results)
sort!(svm_dist_df, :accuracy, rev = true)
first(svm_dist_df, 10)
10×5 DataFrame
Row method distance n_correct n_total accuracy
String String Int64 Int64 Float64
1 SVM-dist (C=100.0) Cubical H1 Betti 50 72 0.694444
2 SVM-dist (C=10.0) Cubical H1 Betti 48 72 0.666667
3 SVM-dist (C=10.0) Rips Betti 44 72 0.611111
4 SVM-dist (C=100.0) Cubical H0 PI 43 72 0.597222
5 SVM-dist (C=100.0) Rips Betti 43 72 0.597222
6 SVM-dist (C=100.0) Cubical H1 PI 42 72 0.583333
7 SVM-dist (C=1.0) Cubical H1 Betti 42 72 0.583333
8 SVM-dist (C=10.0) Rips Wass-1 41 72 0.569444
9 SVM-dist (C=1.0) Cubical H1 PI 39 72 0.541667
10 SVM-dist (C=1.0) Rips Bottleneck 38 72 0.527778

5.2.3 LDA (Linear Discriminant Analysis)

Linear Discriminant Analysis (LDA)

LDA finds a linear projection of the feature space that maximizes the ratio of between-class variance to within-class variance. The projected data is then classified with a simple 1-NN rule. LDA is a classical method that works well when classes are approximately Gaussian and the number of features is not too large relative to the number of samples. It provides an interpretable low-dimensional embedding.

In [31]:
lda_results = []
for (feat_name, X) in feature_sets
    r = loocv_lda(X, labels)
    push!(lda_results, (
        method = "LDA",
        features = feat_name,
        n_correct = sum(r.predictions .== labels),
        n_total = length(labels),
        accuracy = r.accuracy
    ))
end

lda_df = DataFrame(lda_results)
sort!(lda_df, :accuracy, rev = true)
lda_df
6×5 DataFrame
Row method features n_correct n_total accuracy
String String Int64 Int64 Float64
1 LDA Multi-filtration H0+H1 62 72 0.861111
2 LDA All H0+H1 stats 57 72 0.791667
3 LDA Multi-filtration H1 55 72 0.763889
4 LDA All H1 stats 42 72 0.583333
5 LDA Rips full 39 72 0.541667
6 LDA Rips stats 38 72 0.527778

5.2.4 Random Forest

Random Forest

A Random Forest is an ensemble of decision trees, each trained on a bootstrap sample of the data, with a random subset of features considered at each split. The final prediction is the majority vote across all trees. Random Forests are far more resistant to overfitting than a single tree, handle high-dimensional features well, and provide built-in feature importance estimates. They are a strong baseline for tabular data classification tasks.

In [32]:
rf_results = []
for (feat_name, X) in feature_sets
    for n_trees in [100, 500]
        r = loocv_random_forest(X, labels; n_trees = n_trees)
        m = classification_metrics(labels, r.predictions)
        push!(rf_results, (
            method = "RF (T=$n_trees)",
            features = feat_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy,
            balanced_accuracy = m.balanced_accuracy,
            macro_f1 = m.macro_f1
        ))

        rb = loocv_random_forest_balanced(X, labels; n_trees = n_trees, rng_seed = 20260223)
        mb = classification_metrics(labels, rb.predictions)
        push!(rf_results, (
            method = "Balanced RF (T=$n_trees)",
            features = feat_name,
            n_correct = sum(rb.predictions .== labels),
            n_total = length(labels),
            accuracy = rb.accuracy,
            balanced_accuracy = mb.balanced_accuracy,
            macro_f1 = mb.macro_f1
        ))
    end
end

rf_df = DataFrame(rf_results)
sort!(rf_df, :accuracy, rev = true)
first(rf_df, 12)
12×7 DataFrame

| Row | method               | features               | n_correct | n_total | accuracy | balanced_accuracy | macro_f1 |
|----:|:---------------------|:-----------------------|----------:|--------:|---------:|------------------:|---------:|
|   1 | Balanced RF (T=100)  | All H0+H1 stats        |        58 |      72 | 0.805556 |          0.727652 | 0.722311 |
|   2 | Balanced RF (T=500)  | All H0+H1 stats        |        56 |      72 | 0.777778 |          0.710985 | 0.700823 |
|   3 | RF (T=500)           | All H1 stats           |        55 |      72 | 0.763889 |          0.669048 | 0.640582 |
|   4 | RF (T=100)           | All H0+H1 stats        |        55 |      72 | 0.763889 |          0.665909 | 0.640369 |
|   5 | RF (T=500)           | All H0+H1 stats        |        55 |      72 | 0.763889 |          0.661742 | 0.638974 |
|   6 | RF (T=100)           | Multi-filtration H1    |        55 |      72 | 0.763889 |          0.658333 | 0.623894 |
|   7 | Balanced RF (T=500)  | Multi-filtration H0+H1 |        55 |      72 | 0.763889 |          0.735985 | 0.723298 |
|   8 | Balanced RF (T=100)  | Multi-filtration H1    |        54 |      72 | 0.75     |          0.710985 | 0.697298 |
|   9 | RF (T=500)           | Multi-filtration H1    |        54 |      72 | 0.75     |          0.641667 | 0.618705 |
|  10 | Balanced RF (T=500)  | Multi-filtration H1    |        54 |      72 | 0.75     |          0.715909 | 0.698095 |
|  11 | RF (T=100)           | Multi-filtration H0+H1 |        54 |      72 | 0.75     |          0.645076 | 0.620648 |
|  12 | Balanced RF (T=100)  | Multi-filtration H0+H1 |        53 |      72 | 0.736111 |          0.702652 | 0.68684  |

6 Combined distance analysis

We combine the best topology-aware distance with a statistics-based distance using convex combinations: \[D_{\text{combined}}(\alpha) = \alpha \cdot D_1^* + (1 - \alpha) \cdot D_2^*\] where \(D_1^*\) and \(D_2^*\) are distances normalized to \([0, 1]\).
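The mixing step can be sketched directly; `normalize01` and `combine_distances` below are illustrative helpers, assuming min-max rescaling (the grid search in `combined_distance_grid_search` sweeps α and k on top of this):

```julia
# Sketch of the convex combination of two distance matrices: rescale each
# to [0, 1], then mix with weight α. Assumes min-max normalization; the
# package's internals may differ in details.
normalize01(D) = (D .- minimum(D)) ./ (maximum(D) - minimum(D))

combine_distances(D1, D2, α) = α .* normalize01(D1) .+ (1 - α) .* normalize01(D2)

D1 = [0.0 2.0; 2.0 0.0]
D2 = [0.0 10.0; 10.0 0.0]
combine_distances(D1, D2, 0.6)   # off-diagonal entries are 0.6·1 + 0.4·1 = 1.0
```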

In [33]:
stats_for_distance = zscore_normalize(sanitize_feature_matrix(stats_all))
stats_vectors_norm = [stats_for_distance[i, :] for i in axes(stats_for_distance, 1)]
D_stats = pairwise_distance(stats_vectors_norm, euclidean)

# Try combining best Rips distances with stats distance
grid_rips_w1 = combined_distance_grid_search(D_wass1_rips, D_stats, labels)
grid_rips_w2 = combined_distance_grid_search(D_wass2_rips, D_stats, labels)

println("Top 5 combinations (Rips Wass-1 + Stats):")
for r in grid_rips_w1[1:min(5, end)]
    println("  α=$(round(r.alpha, digits=1)), k=$(r.k): $(r.n_correct)/$(length(labels)) ($(round(r.accuracy * 100, digits=1))%)")
end

println("\nTop 5 combinations (Rips Wass-2 + Stats):")
for r in grid_rips_w2[1:min(5, end)]
    println("  α=$(round(r.alpha, digits=1)), k=$(r.k): $(r.n_correct)/$(length(labels)) ($(round(r.accuracy * 100, digits=1))%)")
end
Top 5 combinations (Rips Wass-1 + Stats):
  α=0.6, k=1: 54/72 (75.0%)
  α=0.7, k=1: 54/72 (75.0%)
  α=0.8, k=1: 54/72 (75.0%)
  α=0.3, k=1: 53/72 (73.6%)
  α=0.5, k=1: 52/72 (72.2%)

Top 5 combinations (Rips Wass-2 + Stats):
  α=0.6, k=3: 53/72 (73.6%)
  α=0.7, k=5: 53/72 (73.6%)
  α=0.5, k=1: 51/72 (70.8%)
  α=0.6, k=1: 51/72 (70.8%)
  α=0.8, k=5: 51/72 (70.8%)
In [34]:
# Visualize the grid search
alphas = 0.0:0.1:1.0
ks = [1, 3, 5]

acc_grid_w1 = zeros(length(alphas), length(ks))
for r in grid_rips_w1
    i = findfirst(==(r.alpha), alphas)
    j = findfirst(==(r.k), ks)
    if !isnothing(i) && !isnothing(j)
        acc_grid_w1[i, j] = r.accuracy
    end
end

acc_grid_w2 = zeros(length(alphas), length(ks))
for r in grid_rips_w2
    i = findfirst(==(r.alpha), alphas)
    j = findfirst(==(r.k), ks)
    if !isnothing(i) && !isnothing(j)
        acc_grid_w2[i, j] = r.accuracy
    end
end

p1 = heatmap(string.(ks), string.(collect(alphas)),
        acc_grid_w1,
        xlabel = "k", ylabel = "α (Rips Wass-1 weight)",
        title = "Rips Wass-1 + Stats",
        color = :Blues, clims = (0.3, 1.0))
p2 = heatmap(string.(ks), string.(collect(alphas)),
        acc_grid_w2,
        xlabel = "k", ylabel = "α (Rips Wass-2 weight)",
        title = "Rips Wass-2 + Stats",
        color = :Blues, clims = (0.3, 1.0))
plot(p1, p2, layout = (1, 2), size = (1000, 450))

7 Ensemble classification

Ensemble methods (majority voting)

Ensemble methods combine predictions from multiple classifiers. In majority voting, each classifier casts a vote for its predicted class, and the class with the most votes wins. In weighted voting, each classifier’s vote is weighted by its individual accuracy, giving more influence to better classifiers. Ensembles are more robust than individual classifiers because different methods tend to make different errors.
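A minimal sketch of the weighted rule for a single sample; `weighted_vote` is a hypothetical stand-in for the per-sample logic inside `ensemble_vote`:

```julia
# Sketch of weighted voting for one sample: each classifier's vote is
# weighted by its LOOCV accuracy and the class with the largest total wins.
# Illustrative stand-in for ensemble_vote.
function weighted_vote(votes::Vector{String}, weights::Vector{Float64})
    scores = Dict{String,Float64}()
    for (v, w) in zip(votes, weights)
        scores[v] = get(scores, v, 0.0) + w
    end
    findmax(scores)[2]   # class with the highest accumulated weight
end

votes   = ["Tabanidae", "Tipulidae", "Tipulidae", "Tabanidae"]
weights = [0.86, 0.67, 0.75, 0.81]
weighted_vote(votes, weights)   # "Tabanidae": 0.86 + 0.81 > 0.67 + 0.75
```

With uniform weights this reduces to plain majority voting.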

We combine the best classifiers from each method family:

In [35]:
# Best distance-based k-NN
best_knn_row = knn_df[1, :]
D_best_knn = distances[best_knn_row.distance]
k_best = parse(Int, match(r"k=(\d+)", best_knn_row.method)[1])
knn_best = loocv_knn(D_best_knn, labels; k = k_best)

# Best SVM on features
best_svm_row = svm_df[1, :]
best_svm_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_svm_row.features]
best_svm_kernel = occursin("RBF", best_svm_row.method) ? LIBSVM.Kernel.RadialBasis : LIBSVM.Kernel.Linear
best_svm_cost = parse(Float64, match(r"C=([\d.]+)", best_svm_row.method)[1])
svm_best = loocv_svm(best_svm_X, labels; kernel = best_svm_kernel, cost = best_svm_cost)

# Best Random Forest
best_rf_row = rf_df[1, :]
best_rf_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_rf_row.features]
best_rf_ntrees = parse(Int, match(r"T=(\d+)", best_rf_row.method)[1])
best_rf_balanced = occursin("Balanced RF", best_rf_row.method)
if best_rf_balanced
    rf_best = loocv_random_forest_balanced(best_rf_X, labels; n_trees = best_rf_ntrees, rng_seed = 20260223)
else
    rf_best = loocv_random_forest(best_rf_X, labels; n_trees = best_rf_ntrees)
end

# Best LDA
best_lda_row = lda_df[1, :]
best_lda_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_lda_row.features]
lda_best = loocv_lda(best_lda_X, labels)

# Ensemble: majority vote
predictions_list = [knn_best.predictions, svm_best.predictions, rf_best.predictions, lda_best.predictions]
accuracies = [knn_best.accuracy, svm_best.accuracy, rf_best.accuracy, lda_best.accuracy]

ensemble_preds = ensemble_vote(predictions_list)
ensemble_acc = mean(ensemble_preds .== labels)

ensemble_preds_w = ensemble_vote(predictions_list; weights = accuracies)
ensemble_acc_w = mean(ensemble_preds_w .== labels)

println("=== Ensemble Results ===")
println("Individual classifiers:")
println("  k-NN ($(best_knn_row.distance), k=$k_best): $(round(knn_best.accuracy * 100, digits=1))%")
println("  SVM ($(best_svm_row.method), $(best_svm_row.features)): $(round(svm_best.accuracy * 100, digits=1))%")
println("  RF ($(best_rf_row.method), $(best_rf_row.features)): $(round(rf_best.accuracy * 100, digits=1))%")
println("  LDA ($(best_lda_row.features)): $(round(lda_best.accuracy * 100, digits=1))%")
println()
println("Ensemble (majority vote): $(sum(ensemble_preds .== labels))/$(length(labels)) ($(round(ensemble_acc * 100, digits=1))%)")
println("Ensemble (weighted vote): $(sum(ensemble_preds_w .== labels))/$(length(labels)) ($(round(ensemble_acc_w * 100, digits=1))%)")
=== Ensemble Results ===
Individual classifiers:
  k-NN (Cubical H1 Betti, k=3): 66.7%
  SVM (SVM (Linear, C=0.1), Multi-filtration H1): 83.3%
  RF (Balanced RF (T=100), All H0+H1 stats): 80.6%
  LDA (Multi-filtration H0+H1): 86.1%

Ensemble (majority vote): 59/72 (81.9%)
Ensemble (weighted vote): 62/72 (86.1%)

8 Comprehensive comparison

In [36]:
all_results = []

# Distance-based (top 5)
for row in eachrow(first(knn_df, 5))
    push!(all_results, (
        category = "Distance-based",
        method = "$(row.method) [$(row.distance)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# SVM on distances (top 3)
for row in eachrow(first(svm_dist_df, 3))
    push!(all_results, (
        category = "Distance-based",
        method = "$(row.method) [$(row.distance)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# LDA
for row in eachrow(lda_df)
    push!(all_results, (
        category = "Feature-based",
        method = "LDA [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# SVM on features (top 5)
for row in eachrow(first(svm_df, 5))
    push!(all_results, (
        category = "Feature-based",
        method = "$(row.method) [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# Random Forest (top 3)
for row in eachrow(first(rf_df, 3))
    push!(all_results, (
        category = "Feature-based",
        method = "$(row.method) [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# Ensembles
push!(all_results, (category = "Ensemble", method = "Majority vote (4 classifiers)",
    accuracy = ensemble_acc, n_correct = sum(ensemble_preds .== labels), n_total = length(labels)))
push!(all_results, (category = "Ensemble", method = "Weighted vote (4 classifiers)",
    accuracy = ensemble_acc_w, n_correct = sum(ensemble_preds_w .== labels), n_total = length(labels)))

# Combined distances
best_rips_comb_w1 = grid_rips_w1[1]
push!(all_results, (category = "Combined distance",
    method = "Rips Wass-1 + Stats (α=$(round(best_rips_comb_w1.alpha, digits=1)), k=$(best_rips_comb_w1.k))",
    accuracy = best_rips_comb_w1.accuracy, n_correct = best_rips_comb_w1.n_correct, n_total = length(labels)))

best_rips_comb_w2 = grid_rips_w2[1]
push!(all_results, (category = "Combined distance",
    method = "Rips Wass-2 + Stats (α=$(round(best_rips_comb_w2.alpha, digits=1)), k=$(best_rips_comb_w2.k))",
    accuracy = best_rips_comb_w2.accuracy, n_correct = best_rips_comb_w2.n_correct, n_total = length(labels)))

comparison_df = DataFrame(all_results)
sort!(comparison_df, :accuracy, rev = true)
comparison_df
26×5 DataFrame

| Row | category          | method                                     | accuracy | n_correct | n_total |
|----:|:------------------|:-------------------------------------------|---------:|----------:|--------:|
|   1 | Feature-based     | LDA [Multi-filtration H0+H1]               | 0.861111 |        62 |      72 |
|   2 | Ensemble          | Weighted vote (4 classifiers)              | 0.861111 |        62 |      72 |
|   3 | Feature-based     | SVM (Linear, C=0.1) [Multi-filtration H1]  | 0.833333 |        60 |      72 |
|   4 | Feature-based     | SVM (Linear, C=1.0) [Multi-filtration H1]  | 0.833333 |        60 |      72 |
|   5 | Feature-based     | SVM (Linear, C=10.0) [Multi-filtration H1] | 0.833333 |        60 |      72 |
|   6 | Feature-based     | SVM (Linear, C=100.0) [Multi-filtration H1]| 0.833333 |        60 |      72 |
|   7 | Ensemble          | Majority vote (4 classifiers)              | 0.819444 |        59 |      72 |
|   8 | Feature-based     | Balanced RF (T=100) [All H0+H1 stats]      | 0.805556 |        58 |      72 |
|   9 | Feature-based     | LDA [All H0+H1 stats]                      | 0.791667 |        57 |      72 |
|  10 | Feature-based     | SVM (Linear, C=0.1) [All H0+H1 stats]      | 0.777778 |        56 |      72 |
|  11 | Feature-based     | Balanced RF (T=500) [All H0+H1 stats]      | 0.777778 |        56 |      72 |
|  12 | Feature-based     | LDA [Multi-filtration H1]                  | 0.763889 |        55 |      72 |
|  13 | Feature-based     | RF (T=500) [All H1 stats]                  | 0.763889 |        55 |      72 |
|  14 | Combined distance | Rips Wass-1 + Stats (α=0.6, k=1)           | 0.75     |        54 |      72 |
|  15 | Combined distance | Rips Wass-2 + Stats (α=0.6, k=3)           | 0.736111 |        53 |      72 |
|  16 | Distance-based    | SVM-dist (C=100.0) [Cubical H1 Betti]      | 0.694444 |        50 |      72 |
|  17 | Distance-based    | W-kNN (k=3) [Cubical H1 Betti]             | 0.680556 |        49 |      72 |
|  18 | Distance-based    | k-NN (k=3) [Cubical H1 Betti]              | 0.666667 |        48 |      72 |
|  19 | Distance-based    | k-NN (k=1) [Rips Betti]                    | 0.666667 |        48 |      72 |
|  20 | Distance-based    | W-kNN (k=1) [Rips Betti]                   | 0.666667 |        48 |      72 |
|  21 | Distance-based    | SVM-dist (C=10.0) [Cubical H1 Betti]       | 0.666667 |        48 |      72 |
|  22 | Distance-based    | k-NN (k=1) [Cubical H1 Betti]              | 0.652778 |        47 |      72 |
|  23 | Distance-based    | SVM-dist (C=10.0) [Rips Betti]             | 0.611111 |        44 |      72 |
|  24 | Feature-based     | LDA [All H1 stats]                         | 0.583333 |        42 |      72 |
|  25 | Feature-based     | LDA [Rips full]                            | 0.541667 |        39 |      72 |
|  26 | Feature-based     | LDA [Rips stats]                           | 0.527778 |        38 |      72 |

9 Best classifier evaluation

In [37]:
best_overall = comparison_df[1, :]
println("=== Best Method ===")
println("$(best_overall.category): $(best_overall.method)")
println("Accuracy: $(best_overall.n_correct)/$(best_overall.n_total) ($(round(best_overall.accuracy * 100, digits=1))%)")

ci = wilson_ci(best_overall.n_correct, best_overall.n_total)
println("95% Wilson CI: [$(round(ci.lower * 100, digits=1))%, $(round(ci.upper * 100, digits=1))%]")
=== Best Method ===
Feature-based: LDA [Multi-filtration H0+H1]
Accuracy: 62/72 (86.1%)
95% Wilson CI: [76.3%, 92.3%]
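For reference, the Wilson interval can be written out explicitly; this sketch reproduces the numbers above and should agree with `wilson_ci` up to rounding:

```julia
# Wilson score interval for a binomial proportion (z = 1.96 for 95%).
# Minimal sketch for reference; the notebook's wilson_ci may differ in details.
function wilson_interval(successes::Int, n::Int; z = 1.96)
    p = successes / n
    denom = 1 + z^2 / n
    center = (p + z^2 / (2n)) / denom
    half = z / denom * sqrt(p * (1 - p) / n + z^2 / (4n^2))
    (lower = center - half, upper = center + half)
end

ci = wilson_interval(62, 72)
round.((ci.lower, ci.upper) .* 100, digits = 1)   # (76.3, 92.3)
```

Unlike the naive normal-approximation interval, the Wilson interval stays inside [0, 1] and behaves sensibly for small n.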

9.1 Confusion matrix

In [38]:
# Use ensemble predictions for confusion matrix
final_preds = ensemble_acc_w >= ensemble_acc ? ensemble_preds_w : ensemble_preds
final_method = ensemble_acc_w >= ensemble_acc ? "Weighted ensemble" : "Majority ensemble"

cm_result = confusion_matrix(labels, final_preds)
classes = cm_result.classes

println("=== Confusion Matrix ($final_method) ===")
println("Per-class accuracy:")
for (i, cls) in enumerate(classes)
    correct = cm_result.matrix[i, i]
    total = sum(cm_result.matrix[i, :])
    println("  $(cls): $(correct)/$(total) ($(round(correct / total * 100, digits=1))%)")
end
=== Confusion Matrix (Weighted ensemble) ===
Per-class accuracy:
  Asilidae: 7/8 (87.5%)
  Bibionidae: 6/6 (100.0%)
  Ceratopogonidae: 8/8 (100.0%)
  Chironomidae: 6/8 (75.0%)
  Pelecorhynchidae: 0/2 (0.0%)
  Rhagionidae: 1/4 (25.0%)
  Sciaridae: 5/6 (83.3%)
  Simuliidae: 7/7 (100.0%)
  Tabanidae: 11/11 (100.0%)
  Tipulidae: 11/12 (91.7%)
In [39]:
heatmap(cm_result.matrix,
        xticks = (1:length(classes), classes),
        yticks = (1:length(classes), classes),
        xlabel = "Predicted", ylabel = "True",
        title = "Confusion Matrix ($final_method)",
        color = :Blues,
        clims = (0, maximum(cm_result.matrix)),
        xrotation = 45, size = (700, 600))

10 Honest evaluation (Nested LOOCV)

The distance-combination nested result is unstable for this dataset. Instead, we perform an honest nested LOOCV for the strongest family of models (Random Forest on statistics): the outer loop holds out one sample and the inner loop tunes RF hyperparameters using only the training fold.

Nested cross-validation

Standard LOOCV can give optimistically biased estimates when hyperparameters are tuned on the same data. Nested LOOCV adds an inner cross-validation loop: for each held-out test sample, the best hyperparameters are selected using only the training fold. This provides an unbiased estimate of generalization performance.
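The two-loop structure can be sketched as follows; `tune` and `fit_predict` are hypothetical callbacks standing in for the inner-CV hyperparameter search and the final model fit:

```julia
# Skeleton of nested LOOCV. `tune` and `fit_predict` are hypothetical
# callbacks: `tune` runs the inner CV on the training fold only, so the
# held-out sample never influences hyperparameter choice.
function nested_loocv(X, y, tune, fit_predict)
    n = length(y)
    preds = similar(y)
    for i in 1:n
        train = setdiff(1:n, i)                 # outer loop: hold out sample i
        params = tune(X[train, :], y[train])    # inner loop sees training fold only
        preds[i] = fit_predict(X[train, :], y[train], X[i:i, :], params)
    end
    (predictions = preds, accuracy = sum(preds .== y) / n)
end

# Toy usage: a "model" that ignores params and predicts the training fold's
# majority class.
majority(ys) = findmax(Dict(c => count(==(c), ys) for c in unique(ys)))[2]
X = zeros(4, 1); y = ["a", "a", "a", "b"]
nested_loocv(X, y, (Xt, yt) -> nothing, (Xt, yt, Xq, p) -> majority(yt)).accuracy  # 0.75
```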

In [40]:
nested_rf = nested_loocv_random_forest(
    X_stats_all, labels;
    n_trees_grid = [200, 500],
    max_depth_grid = [-1],
    min_samples_leaf_grid = [1, 2],
    inner_folds = 4,
    balanced = true,
    rng_seed = 20260223
)
n_correct_nested = sum(nested_rf.predictions .== labels)

println("=== Nested LOOCV Result ===")
println("Model: Balanced Random Forest (All stats)")
println("Accuracy: $(n_correct_nested)/$(length(labels)) ($(round(nested_rf.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_rf.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_rf.macro_f1 * 100, digits=1))%")

ci_nested = wilson_ci(n_correct_nested, length(labels))
println("95% Wilson CI: [$(round(ci_nested.lower * 100, digits=1))%, $(round(ci_nested.upper * 100, digits=1))%]")
=== Nested LOOCV Result ===
Model: Balanced Random Forest (All stats)
Accuracy: 56/72 (77.8%)
Balanced accuracy: 76.2%
Macro-F1: 76.2%
95% Wilson CI: [66.9%, 85.8%]
In [41]:
cm_nested = confusion_matrix(labels, nested_rf.predictions)
classes_nested = cm_nested.classes

println("Per-class accuracy (Nested LOOCV):")
for (i, cls) in enumerate(classes_nested)
    correct = cm_nested.matrix[i, i]
    total = sum(cm_nested.matrix[i, :])
    println("  $(cls): $(correct)/$(total) ($(round(correct / total * 100, digits=1))%)")
end
Per-class accuracy (Nested LOOCV):
  Asilidae: 6/8 (75.0%)
  Bibionidae: 6/6 (100.0%)
  Ceratopogonidae: 6/8 (75.0%)
  Chironomidae: 5/8 (62.5%)
  Pelecorhynchidae: 1/2 (50.0%)
  Rhagionidae: 2/4 (50.0%)
  Sciaridae: 6/6 (100.0%)
  Simuliidae: 7/7 (100.0%)
  Tabanidae: 10/11 (90.9%)
  Tipulidae: 7/12 (58.3%)
In [42]:
heatmap(cm_nested.matrix,
        xticks = (1:length(classes_nested), classes_nested),
        yticks = (1:length(classes_nested), classes_nested),
        xlabel = "Predicted", ylabel = "True",
        title = "Confusion Matrix (Nested LOOCV - Balanced RF)",
        color = :Blues,
        clims = (0, maximum(cm_nested.matrix)),
        xrotation = 45, size = (700, 600))

10.1 Nested LOOCV for Multi-filtration SVM

Why is nested CV needed here?

The Multi-filtration feature matrix X_multi has 3388 features for only 72 samples (a ~47:1 feature-to-sample ratio). In such high-dimensional settings, an SVM with an RBF kernel can find separating hyperplanes even for random data. Furthermore, selecting the best kernel and cost parameter from many LOOCV runs introduces selection bias: the reported accuracy of the “best” configuration is upward-biased. Nested LOOCV removes this bias by selecting hyperparameters using only the training fold.

We evaluate the Multi-filtration SVM both with and without PCA dimensionality reduction:

In [43]:
# Nested LOOCV for Multi-filtration SVM (no PCA)
nested_svm_multi = nested_loocv_svm(
    X_multi, labels;
    kernels = [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear],
    costs = [0.1, 1.0, 10.0, 100.0],
    use_pca = false,
    inner_folds = 5,
    rng_seed = 20260223
)

println("=== Nested LOOCV: Multi-filtration SVM (no PCA) ===")
n_corr = sum(nested_svm_multi.predictions .== labels)
println("Accuracy: $(n_corr)/$(length(labels)) ($(round(nested_svm_multi.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_svm_multi.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_svm_multi.macro_f1 * 100, digits=1))%")

ci_svm = wilson_ci(n_corr, length(labels))
println("95% Wilson CI: [$(round(ci_svm.lower * 100, digits=1))%, $(round(ci_svm.upper * 100, digits=1))%]")

# Show which hyperparameters were selected in each fold
param_counts = Dict{String, Int}()
for p in nested_svm_multi.params
    key = "$(p.kernel), C=$(p.cost)"
    param_counts[key] = get(param_counts, key, 0) + 1
end
println("\nSelected hyperparameters across folds:")
for (k, v) in sort(collect(param_counts), by=last, rev=true)
    println("  $k: $v/$(length(labels)) folds")
end
=== Nested LOOCV: Multi-filtration SVM (no PCA) ===
Accuracy: 56/72 (77.8%)
Balanced accuracy: 70.8%
Macro-F1: 70.1%
95% Wilson CI: [66.9%, 85.8%]

Selected hyperparameters across folds:
  Linear, C=0.1: 71/72 folds
  RBF, C=10.0: 1/72 folds
In [44]:
# Nested LOOCV for Multi-filtration SVM with PCA (95% variance)
nested_svm_pca = nested_loocv_svm(
    X_multi, labels;
    kernels = [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear],
    costs = [0.1, 1.0, 10.0, 100.0],
    use_pca = true,
    variance_ratio = 0.95,
    inner_folds = 5,
    rng_seed = 20260223
)

println("=== Nested LOOCV: Multi-filtration SVM + PCA (95% var) ===")
n_corr_pca = sum(nested_svm_pca.predictions .== labels)
println("Accuracy: $(n_corr_pca)/$(length(labels)) ($(round(nested_svm_pca.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_svm_pca.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_svm_pca.macro_f1 * 100, digits=1))%")

ci_pca = wilson_ci(n_corr_pca, length(labels))
println("95% Wilson CI: [$(round(ci_pca.lower * 100, digits=1))%, $(round(ci_pca.upper * 100, digits=1))%]")
=== Nested LOOCV: Multi-filtration SVM + PCA (95% var) ===
Accuracy: 53/72 (73.6%)
Balanced accuracy: 68.6%
Macro-F1: 66.7%
95% Wilson CI: [62.4%, 82.4%]

For comparison, a simple PCA + SVM (non-nested) on the Multi-filtration features:

In [45]:
pca_svm_results = []
for kernel in [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear]
    for cost in [1.0, 10.0]
        kernel_name = kernel == LIBSVM.Kernel.RadialBasis ? "RBF" : "Linear"
        r = loocv_svm_pca(X_multi, labels;
                          variance_ratio = 0.95, kernel = kernel, cost = cost)
        push!(pca_svm_results, (
            method = "PCA+SVM ($kernel_name, C=$cost)",
            accuracy = r.accuracy,
            n_correct = sum(r.predictions .== labels),
            n_components = r.median_n_components
        ))
    end
end

pca_df = DataFrame(pca_svm_results)
sort!(pca_df, :accuracy, rev = true)
pca_df
4×4 DataFrame

| Row | method                   | accuracy  | n_correct | n_components |
|----:|:-------------------------|----------:|----------:|-------------:|
|   1 | PCA+SVM (Linear, C=1.0)  | 0.736111  |        53 |           13 |
|   2 | PCA+SVM (Linear, C=10.0) | 0.736111  |        53 |           13 |
|   3 | PCA+SVM (RBF, C=1.0)     | 0.0833333 |         6 |           13 |
|   4 | PCA+SVM (RBF, C=10.0)    | 0.0833333 |         6 |           13 |

10.2 Permutation test

Permutation test for feature-based classifiers

A permutation test assesses whether the classifier’s accuracy is significantly better than chance. We shuffle the labels many times, recompute LOOCV accuracy each time, and measure how often the shuffled accuracy matches or exceeds the observed accuracy. If the observed accuracy is far above the null distribution, we can be confident the features contain genuine discriminative signal — even if the absolute accuracy estimate may be optimistically biased.
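A minimal sketch of the procedure, with a generic `score` callback standing in for the full LOOCV rescoring that `permutation_test_svm` performs on each shuffle:

```julia
using Random

# Sketch of a permutation test: shuffle the labels, rescore, and count how
# often the null score matches or beats the observed one. The add-one term
# avoids reporting p = 0 from a finite number of shuffles.
function permutation_pvalue(score, labels; n_permutations = 500,
                            rng = MersenneTwister(20260223))
    observed = score(labels)
    hits = count(_ -> score(shuffle(rng, labels)) >= observed, 1:n_permutations)
    (observed = observed, p_value = (1 + hits) / (1 + n_permutations))
end

# Toy usage: a "classifier" whose fixed predictions happen to match exactly,
# so only a perfectly matching shuffle can tie the observed score.
preds = repeat(["a", "b"], 5)
acc(labels) = sum(labels .== preds) / length(labels)
permutation_pvalue(acc, preds).p_value   # small: genuine signal
```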

In [46]:
# Permutation test for Multi-filtration SVM (takes a few minutes)
perm_multi = permutation_test_svm(
    X_multi, labels;
    n_permutations = 500,
    kernel = LIBSVM.Kernel.RadialBasis,
    cost = 10.0
)

println("=== Permutation Test: Multi-filtration SVM (RBF, C=10) ===")
println("Observed LOOCV accuracy: $(round(perm_multi.observed * 100, digits=1))%")
println("Null distribution: mean=$(round(perm_multi.perm_mean * 100, digits=1))%, std=$(round(perm_multi.perm_std * 100, digits=1))%")
println("Max null accuracy: $(round(perm_multi.perm_max * 100, digits=1))%")
println("p-value: $(perm_multi.p_value)")

10.3 Feature selection via RF importance

Why feature selection helps

With 3388 features and \(n = 72\) samples, overfitting is the main accuracy bottleneck. Selecting the top features by Random Forest impurity importance reduces dimensionality and can improve generalization. Critically, feature selection is performed inside each LOOCV fold to avoid data leakage: each fold selects features using only its training data.
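The per-fold selection step can be sketched as follows, assuming the importances come from a forest fit on the training rows only (`loocv_rf_with_selection` handles the real per-fold fit):

```julia
# Sketch of leakage-free selection for one fold: importances are computed
# from the training rows only, and both folds are then restricted to the
# selected columns.
top_k_features(importances, k) = sortperm(importances, rev = true)[1:k]

importances = [0.01, 0.40, 0.05, 0.30, 0.24]   # pretend these came from the training fold
cols = top_k_features(importances, 2)          # [2, 4]
# X_train_sel = X_train[:, cols]; X_test_sel = X_test[:, cols]
```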

In [47]:
# RF with feature selection inside each fold (honest evaluation)
for top_k in [20, 30, 50]
    r_sel = loocv_rf_with_selection(
        X_multi, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    n_corr_sel = sum(r_sel.predictions .== labels)
    println("RF + selection (top_k=$top_k): $(n_corr_sel)/$(length(labels)) ($(round(r_sel.accuracy * 100, digits=1))%)")
    println("  Balanced acc: $(round(r_sel.balanced_accuracy * 100, digits=1))%  Macro-F1: $(round(r_sel.macro_f1 * 100, digits=1))%")
end
RF + selection (top_k=20): 47/72 (65.3%)
  Balanced acc: 58.9%  Macro-F1: 57.2%
RF + selection (top_k=30): 52/72 (72.2%)
  Balanced acc: 66.8%  Macro-F1: 64.3%
RF + selection (top_k=50): 46/72 (63.9%)
  Balanced acc: 58.0%  Macro-F1: 56.2%
In [48]:
# Also try on the H0+H1 stats-only feature matrix (lower dimensional)
for top_k in [20, 30, 50]
    r_sel_stats = loocv_rf_with_selection(
        X_stats_all, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    n_corr_sel = sum(r_sel_stats.predictions .== labels)
    println("RF + selection on stats (top_k=$top_k): $(n_corr_sel)/$(length(labels)) ($(round(r_sel_stats.accuracy * 100, digits=1))%)")
end
RF + selection on stats (top_k=20): 47/72 (65.3%)
RF + selection on stats (top_k=30): 52/72 (72.2%)
RF + selection on stats (top_k=50): 46/72 (63.9%)

10.4 Nested LOOCV with multi-distance selection

Multi-distance nested LOOCV

With many distance matrices available, we need an honest way to select the best one. The inner loop evaluates all (distance, k) combinations on the training fold; the outer loop provides an unbiased accuracy estimate.
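The inner-loop search can be sketched with a hypothetical `loocv_acc` callback that scores a (distance, k) pair on the training fold only:

```julia
# Sketch of the inner-loop selection: evaluate every (distance, k) pair on
# the training fold and keep the best scorer. Stand-in for the inner loop
# of nested_loocv_multi_distance.
function best_combo(distance_names, ks, loocv_acc)
    best = (acc = -Inf, distance = "", k = 0)
    for d in distance_names, k in ks
        acc = loocv_acc(d, k)
        acc > best.acc && (best = (acc = acc, distance = d, k = k))
    end
    best
end

scores = Dict(("Rips Betti", 1) => 0.60, ("Rips Betti", 3) => 0.55,
              ("Cubical H1 Betti", 1) => 0.65)
best_combo(["Rips Betti", "Cubical H1 Betti"], [1, 3],
           (d, k) -> get(scores, (d, k), 0.0))
# (acc = 0.65, distance = "Cubical H1 Betti", k = 1)
```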

In [49]:
nested_multi_dist = nested_loocv_multi_distance(
    distances, labels;
    ks = [1, 3, 5]
)

n_corr_multi = sum(nested_multi_dist.predictions .== labels)
println("=== Nested LOOCV: Multi-distance selection ===")
println("Accuracy: $(n_corr_multi)/$(length(labels)) ($(round(nested_multi_dist.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_multi_dist.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_multi_dist.macro_f1 * 100, digits=1))%")

# Show which distances were selected most often
dist_selection_counts = Dict{String, Int}()
for p in nested_multi_dist.params
    dist_selection_counts[p.distance] = get(dist_selection_counts, p.distance, 0) + 1
end
println("\nSelected distances across folds:")
for (d, v) in sort(collect(dist_selection_counts), by=last, rev=true)
    println("  $d: $v/$(length(labels)) folds")
end
=== Nested LOOCV: Multi-distance selection ===
Accuracy: 30/72 (41.7%)
Balanced accuracy: 36.6%
Macro-F1: 35.5%

Selected distances across folds:
  Cubical H1 Betti: 50/72 folds
  Rips Betti: 21/72 folds
  Cubical H1 PI: 1/72 folds

10.5 Honest comparison summary

In [50]:
honest_results = []

# Nested RF on comprehensive stats
push!(honest_results, (
    method = "Nested LOOCV: Balanced RF (All H0+H1 stats, $(size(X_stats_all, 2)) features)",
    accuracy = nested_rf.accuracy,
    balanced_accuracy = nested_rf.balanced_accuracy,
    macro_f1 = nested_rf.macro_f1,
    n_correct = sum(nested_rf.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# Nested SVM on multi-filtration (no PCA)
push!(honest_results, (
    method = "Nested LOOCV: SVM (Multi-filtration H0+H1, $(size(X_multi, 2)) features)",
    accuracy = nested_svm_multi.accuracy,
    balanced_accuracy = nested_svm_multi.balanced_accuracy,
    macro_f1 = nested_svm_multi.macro_f1,
    n_correct = sum(nested_svm_multi.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# Nested SVM + PCA on multi-filtration
push!(honest_results, (
    method = "Nested LOOCV: SVM + PCA (Multi-filtration H0+H1)",
    accuracy = nested_svm_pca.accuracy,
    balanced_accuracy = nested_svm_pca.balanced_accuracy,
    macro_f1 = nested_svm_pca.macro_f1,
    n_correct = sum(nested_svm_pca.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# RF with feature selection (best top_k)
for top_k in [20, 30, 50]
    r_sel = loocv_rf_with_selection(
        X_multi, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    push!(honest_results, (
        method = "RF + Feature Selection (top_k=$top_k, multi H0+H1)",
        accuracy = r_sel.accuracy,
        balanced_accuracy = r_sel.balanced_accuracy,
        macro_f1 = r_sel.macro_f1,
        n_correct = sum(r_sel.predictions .== labels),
        n_total = length(labels),
        honest = "Yes (selection inside fold)"
    ))
end

# Nested multi-distance selection
push!(honest_results, (
    method = "Nested LOOCV: Multi-distance k-NN",
    accuracy = nested_multi_dist.accuracy,
    balanced_accuracy = nested_multi_dist.balanced_accuracy,
    macro_f1 = nested_multi_dist.macro_f1,
    n_correct = n_corr_multi,
    n_total = length(labels),
    honest = "Yes"
))

# Best k-NN on Wasserstein distances (no hyperparameter selection needed for k=1)
for (wname, D_wass) in [("Wass-1", D_wass1_rips), ("Wass-2", D_wass2_rips)]
    r_knn1 = loocv_knn(D_wass, labels; k = 1)
    m_knn1 = classification_metrics(labels, r_knn1.predictions)
    push!(honest_results, (
        method = "1-NN on Rips $(wname) (no tuning)",
        accuracy = r_knn1.accuracy,
        balanced_accuracy = m_knn1.balanced_accuracy,
        macro_f1 = m_knn1.macro_f1,
        n_correct = sum(r_knn1.predictions .== labels),
        n_total = length(labels),
        honest = "Yes (no hyperparams)"
    ))
end

honest_df = DataFrame(honest_results)
sort!(honest_df, :accuracy, rev = true)
honest_df
9×7 DataFrame

| Row | method                                                    | accuracy | balanced_accuracy | macro_f1 | n_correct | n_total | honest                      |
|----:|:----------------------------------------------------------|---------:|------------------:|---------:|----------:|--------:|:----------------------------|
|   1 | Nested LOOCV: Balanced RF (All H0+H1 stats, 253 features) | 0.777778 |          0.761742 | 0.76207  |        56 |      72 | Yes                         |
|   2 | Nested LOOCV: SVM (Multi-filtration H0+H1, 3388 features) | 0.777778 |          0.707576 | 0.701216 |        56 |      72 | Yes                         |
|   3 | Nested LOOCV: SVM + PCA (Multi-filtration H0+H1)          | 0.736111 |          0.685985 | 0.667089 |        53 |      72 | Yes                         |
|   4 | RF + Feature Selection (top_k=30, multi H0+H1)            | 0.722222 |          0.667532 | 0.64329  |        52 |      72 | Yes (selection inside fold) |
|   5 | RF + Feature Selection (top_k=20, multi H0+H1)            | 0.652778 |          0.588961 | 0.571954 |        47 |      72 | Yes (selection inside fold) |
|   6 | RF + Feature Selection (top_k=50, multi H0+H1)            | 0.638889 |          0.580032 | 0.562345 |        46 |      72 | Yes (selection inside fold) |
|   7 | 1-NN on Rips Wass-2 (no tuning)                           | 0.569444 |          0.476461 | 0.453422 |        41 |      72 | Yes (no hyperparams)        |
|   8 | 1-NN on Rips Wass-1 (no tuning)                           | 0.555556 |          0.491613 | 0.499287 |        40 |      72 | Yes (no hyperparams)        |
|   9 | Nested LOOCV: Multi-distance k-NN                         | 0.416667 |          0.366017 | 0.355016 |        30 |      72 | Yes                         |

11 Discussion

We applied multiple TDA filtration strategies to classify Diptera families from wing venation images. Key findings:

  1. Multiple filtrations are complementary: The Vietoris-Rips filtration on point-cloud samples captures the global loop structure of the wing venation. Directional height filtrations (now 8 directions) encode how topological features are spatially distributed along specific axes, the radial filtration captures the center-to-periphery organization, the EDT filtration captures vein thickness hierarchy, and cubical (grayscale sublevel-set) persistence captures intensity landscape information. Together, these views capture different geometric and topological aspects of the wing.

  2. H0 persistence from directional/radial/EDT filtrations is informative: While H0 is uninformative for Rips on a connected point cloud, H0 from cubical filtrations captures vein branching — when disconnected vein segments merge as the sweep progresses. This is directly related to vein count and branching patterns, a key taxonomic character for Diptera families.

  3. 8 filtration directions improve coverage: Expanding from 4 to 8 directions (every 22.5°) captures oblique vein angles missed previously, providing finer angular resolution of the venation topology.

  4. EDT filtration captures vein thickness: The Euclidean Distance Transform filtration captures the vein thickness hierarchy, which is a diagnostic character (e.g., Tabanidae have thickened C and Sc veins). This provides complementary information to structural topology.

  5. Feature selection reduces overfitting: Random Forest feature importance-based selection (performed inside each CV fold to avoid leakage) reduces the feature-to-sample ratio dramatically, improving generalization in honest nested evaluations.

  6. Multi-distance nested LOOCV provides honest model selection: With many distance matrices available, the nested multi-distance evaluation selects the best (distance, k) combination in the inner loop and provides an unbiased accuracy estimate in the outer loop.

  7. Wasserstein-2 vs Wasserstein-1: Both Wasserstein distances are effective for comparing persistence diagrams, with W-2 penalizing large mismatches more heavily. The comparison reveals whether fine or coarse topological differences are more discriminative.

  8. Statistical rigor: We report LOOCV accuracy with Wilson confidence intervals, nested LOOCV for unbiased evaluation when hyperparameters are tuned, and permutation tests to verify that observed accuracy is significantly above chance level.

11.1 Limitations

  • Class imbalance: Tipulidae has 12 samples while Pelecorhynchidae has only 2, which may affect some classifiers. The filtered analysis (excluding families with < 3 samples) provides a fairer comparison.
  • Image quality and preprocessing parameters (blur, threshold) influence the topological features.
  • The non-nested LOOCV results for feature-based classifiers are optimistically biased due to hyperparameter selection on the evaluation data. The honest comparison table should be preferred.
  • With only 72 samples, confidence intervals remain wide regardless of method.

11.2 Future work

  • Extend dataset with more specimens per family, especially underrepresented families
  • Improve imaging/segmentation quality and reevaluate image-based filtrations with less noise sensitivity
  • Apply extended persistence or zigzag persistence for richer invariants
  • XGBoost or gradient boosting classifiers for tabular feature data
  • Deep learning on persistence images or persistence diagrams directly